From e2f074e16c522bfa55d9bcd344a5ea0ba5a4580f Mon Sep 17 00:00:00 2001
From: Channing Conger
Date: Tue, 16 Jun 2026 19:28:55 -0700
Subject: [PATCH] code-mode: move cell state into library actor (#28599)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

A code-mode cell is a single JavaScript execution that can produce output, call tools, wait for asynchronous work, resume, or be terminated. This PR extracts the existing per-cell run loop into a dedicated actor that owns the cell’s lifecycle state. It is primarily an ownership change rather than a new lifecycle contract: existing behavior now has one clear implementation boundary.

### Architecture

The session service remains responsible for session-wide concerns: allocating cell IDs, storing shared values, creating cells, and routing requests to them. Once a cell is created, its execution state belongs to its actor. Callers interact with the actor through a handle. The actor receives two kinds of input: runtime events and control requests. A single event loop serializes these inputs and applies the lifecycle rules. It tracks the current observer—the caller waiting for an update—along with accumulated output, outstanding callbacks, runtime state, yield deadlines, and termination progress. Observation, termination, completion, and cleanup therefore have one consistent owner. When the runtime has no immediately runnable work and is waiting only on timers or tool results, the actor can return accumulated output and information about outstanding tool calls while keeping the cell available to resume. On completion or termination, it performs the appropriate callback cleanup before publishing the final result and removing the cell from the session. A small host interface connects the actor to session-owned facilities such as tool dispatch, notifications, stored values, and final cell removal, keeping those responsibilities outside the actor itself.

### Why Previously, cell lifecycle state and coordination lived alongside session management. The actor boundary makes each cell a self-contained state machine with a single writer, while the service becomes a registry and adapter around it. This makes lifecycle behavior easier to reason about and test in isolation. It also establishes a clean boundary for later changing where cells run or how they communicate without recreating their lifecycle rules. --- .../code-mode/src/cell_actor/callbacks.rs | 77 + .../code-mode/src/cell_actor/conversions.rs | 60 + codex-rs/code-mode/src/cell_actor/mod.rs | 476 ++++ codex-rs/code-mode/src/cell_actor/tests.rs | 143 ++ codex-rs/code-mode/src/cell_actor/types.rs | 203 ++ codex-rs/code-mode/src/lib.rs | 1 + codex-rs/code-mode/src/runtime/mod.rs | 6 + codex-rs/code-mode/src/service.rs | 1927 ++--------------- .../code-mode/src/service_contract_tests.rs | 301 +-- codex-rs/code-mode/src/service_tests.rs | 972 +++++++++ 10 files changed, 2261 insertions(+), 1905 deletions(-) create mode 100644 codex-rs/code-mode/src/cell_actor/callbacks.rs create mode 100644 codex-rs/code-mode/src/cell_actor/conversions.rs create mode 100644 codex-rs/code-mode/src/cell_actor/mod.rs create mode 100644 codex-rs/code-mode/src/cell_actor/tests.rs create mode 100644 codex-rs/code-mode/src/cell_actor/types.rs create mode 100644 codex-rs/code-mode/src/service_tests.rs diff --git a/codex-rs/code-mode/src/cell_actor/callbacks.rs b/codex-rs/code-mode/src/cell_actor/callbacks.rs new file mode 100644 index 000000000..9db66ed3a --- /dev/null +++ b/codex-rs/code-mode/src/cell_actor/callbacks.rs @@ -0,0 +1,77 @@ +use std::sync::Arc; + +use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +use super::CellHost; +use super::CellToolCall; +use crate::runtime::RuntimeCommand; + +#[derive(Clone, Copy)] +pub(super) enum CallbackCompletion { + DrainNotifications, + Cancel, +} + +pub(super) fn spawn_notification( + tasks: &mut 
JoinSet<()>, + host: Arc, + call_id: String, + text: String, + cancellation_token: CancellationToken, +) { + tasks.spawn(async move { + if let Err(err) = host.notify(call_id, text, cancellation_token).await { + warn!("failed to deliver code mode notification: {err}"); + } + }); +} + +pub(super) fn spawn_tool( + tasks: &mut JoinSet<()>, + host: Arc, + invocation: CellToolCall, + runtime_tx: std::sync::mpsc::Sender, + cancellation_token: CancellationToken, +) { + tasks.spawn(async move { + let id = invocation.id.clone(); + let command = match host.invoke_tool(invocation, cancellation_token).await { + Ok(result) => RuntimeCommand::ToolResponse { id, result }, + Err(error_text) => RuntimeCommand::ToolError { id, error_text }, + }; + let _ = runtime_tx.send(command); + }); +} + +pub(super) async fn finish_callbacks( + cancellation_token: &CancellationToken, + notification_tasks: &mut JoinSet<()>, + tool_tasks: &mut JoinSet<()>, + completion: CallbackCompletion, +) { + if matches!(completion, CallbackCompletion::Cancel) { + cancellation_token.cancel(); + } + drain_tasks(notification_tasks, "notification").await; + cancellation_token.cancel(); + drain_tasks(tool_tasks, "tool").await; +} + +pub(super) fn log_task_result( + task_result: Option>, + description: &str, +) { + if let Some(Err(err)) = task_result + && !err.is_cancelled() + { + warn!("code mode {description} task failed: {err}"); + } +} + +async fn drain_tasks(tasks: &mut JoinSet<()>, description: &str) { + while let Some(result) = tasks.join_next().await { + log_task_result(Some(result), description); + } +} diff --git a/codex-rs/code-mode/src/cell_actor/conversions.rs b/codex-rs/code-mode/src/cell_actor/conversions.rs new file mode 100644 index 000000000..ebc12d29f --- /dev/null +++ b/codex-rs/code-mode/src/cell_actor/conversions.rs @@ -0,0 +1,60 @@ +use codex_code_mode_protocol::CodeModeToolKind; +use codex_code_mode_protocol::ExecuteRequest; +use codex_code_mode_protocol::FunctionCallOutputContentItem; +use 
codex_code_mode_protocol::ImageDetail; +use codex_code_mode_protocol::ToolDefinition; +use codex_protocol::ToolName; + +use super::CellImageDetail; +use super::CellOutputItem; +use super::CellRequest; +use super::CellToolKind; + +pub(super) fn runtime_request(request: CellRequest) -> ExecuteRequest { + ExecuteRequest { + tool_call_id: request.tool_call_id, + enabled_tools: request + .enabled_tools + .into_iter() + .map(|definition| ToolDefinition { + name: definition.name, + tool_name: ToolName { + name: definition.tool_name.name, + namespace: definition.tool_name.namespace, + }, + description: definition.description, + kind: match definition.kind { + CellToolKind::Function => CodeModeToolKind::Function, + CellToolKind::Freeform => CodeModeToolKind::Freeform, + }, + input_schema: None, + output_schema: None, + }) + .collect(), + source: request.source, + yield_time_ms: None, + max_output_tokens: None, + } +} + +pub(super) fn cell_tool_kind(kind: CodeModeToolKind) -> CellToolKind { + match kind { + CodeModeToolKind::Function => CellToolKind::Function, + CodeModeToolKind::Freeform => CellToolKind::Freeform, + } +} + +pub(super) fn output_item(item: FunctionCallOutputContentItem) -> CellOutputItem { + match item { + FunctionCallOutputContentItem::InputText { text } => CellOutputItem::Text { text }, + FunctionCallOutputContentItem::InputImage { image_url, detail } => CellOutputItem::Image { + image_url, + detail: detail.map(|detail| match detail { + ImageDetail::Auto => CellImageDetail::Auto, + ImageDetail::Low => CellImageDetail::Low, + ImageDetail::High => CellImageDetail::High, + ImageDetail::Original => CellImageDetail::Original, + }), + }, + } +} diff --git a/codex-rs/code-mode/src/cell_actor/mod.rs b/codex-rs/code-mode/src/cell_actor/mod.rs new file mode 100644 index 000000000..d42ce48a2 --- /dev/null +++ b/codex-rs/code-mode/src/cell_actor/mod.rs @@ -0,0 +1,476 @@ +mod callbacks; +mod conversions; +mod types; + +use std::collections::HashMap; +use 
std::future::Future; +use std::sync::Arc; + +use serde_json::Value as JsonValue; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; + +use self::callbacks::CallbackCompletion; +use self::callbacks::finish_callbacks; +use self::callbacks::log_task_result; +use self::callbacks::spawn_notification; +use self::callbacks::spawn_tool; +use self::conversions::cell_tool_kind; +use self::conversions::output_item; +use self::conversions::runtime_request; +use self::types::CellCommand; +pub(crate) use self::types::CellError; +pub(crate) use self::types::CellEvent; +pub(crate) use self::types::CellEventFuture; +pub(crate) use self::types::CellHandle; +pub(crate) use self::types::CellHost; +pub(crate) use self::types::CellImageDetail; +pub(crate) use self::types::CellOutputItem; +pub(crate) use self::types::CellRequest; +pub(crate) use self::types::CellToolCall; +pub(crate) use self::types::CellToolDefinition; +pub(crate) use self::types::CellToolKind; +pub(crate) use self::types::CellToolName; +pub(crate) use self::types::ObserveMode; +use crate::runtime::PendingRuntimeMode; +use crate::runtime::RuntimeCommand; +use crate::runtime::RuntimeControlCommand; +use crate::runtime::RuntimeEvent; +use crate::runtime::spawn_runtime; + +pub(crate) struct CellActor; + +impl CellActor { + pub(crate) fn prepare( + request: CellRequest, + stored_values: HashMap, + host: Arc, + initial_observe_mode: ObserveMode, + ) -> Result< + ( + CellHandle, + CellEventFuture, + impl Future + Send + 'static, + ), + String, + > { + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let (initial_response_tx, initial_response_rx) = oneshot::channel(); + let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = spawn_runtime( + stored_values, + runtime_request(request), + event_tx, + PendingRuntimeMode::PauseUntilResumed, + )?; + let cancellation_token = 
CancellationToken::new(); + let handle = CellHandle::new(command_tx, cancellation_token.clone()); + let task = run_cell( + host, + CellContext { + runtime_tx, + runtime_control_tx, + runtime_terminate_handle, + cancellation_token, + }, + event_rx, + command_rx, + Observer { + mode: initial_observe_mode, + response_tx: initial_response_tx, + }, + ); + let initial_response = + Box::pin(async move { initial_response_rx.await.unwrap_or(Err(CellError::Closed)) }); + Ok((handle, initial_response, task)) + } +} + +struct CellContext { + runtime_tx: std::sync::mpsc::Sender, + runtime_control_tx: std::sync::mpsc::Sender, + runtime_terminate_handle: v8::IsolateHandle, + cancellation_token: CancellationToken, +} + +struct Observer { + mode: ObserveMode, + response_tx: oneshot::Sender>, +} + +struct Termination { + response_tx: Option>>, +} + +async fn run_cell( + host: Arc, + context: CellContext, + mut event_rx: mpsc::UnboundedReceiver, + mut command_rx: mpsc::UnboundedReceiver, + initial_observer: Observer, +) { + let CellContext { + runtime_tx, + runtime_control_tx, + runtime_terminate_handle, + cancellation_token, + } = context; + let mut content_items = Vec::new(); + let mut pending_tool_call_ids = Vec::new(); + let mut completed_event = None; + let mut observer = Some(initial_observer); + let mut termination: Option = None; + let mut runtime_closed = false; + let mut runtime_paused = false; + let mut yield_timer: Option>> = None; + let mut notification_tasks = JoinSet::new(); + let mut tool_tasks = JoinSet::new(); + + loop { + let yield_deadline_elapsed = yield_timer + .as_ref() + .is_some_and(|yield_timer| yield_timer.deadline() <= tokio::time::Instant::now()); + tokio::select! 
{ + biased; + maybe_command = command_rx.recv() => { + let Some(command) = maybe_command else { + if completed_event.is_some() { + break; + } + termination = Some(Termination { response_tx: None }); + begin_termination( + &runtime_tx, + &runtime_control_tx, + &runtime_terminate_handle, + &cancellation_token, + ); + if runtime_closed { + break; + } + continue; + }; + match command { + CellCommand::Observe { mode, response_tx } => { + if let Some(event) = completed_event.take() { + let _ = response_tx.send(Ok(event)); + break; + } + if observer.is_some() || termination.is_some() { + let _ = response_tx.send(Err(CellError::Busy)); + continue; + } + observer = Some(Observer { mode, response_tx }); + yield_timer = observer.as_ref().and_then(observer_timer); + resume_for_observation( + mode, + &mut runtime_paused, + &runtime_tx, + &runtime_control_tx, + ); + } + CellCommand::Terminate { response_tx } => { + if let Some(event) = completed_event.take() { + if let Some(response_tx) = response_tx { + let _ = response_tx.send(Ok(event)); + } + break; + } + if termination.is_some() { + if let Some(response_tx) = response_tx { + let _ = response_tx.send(Err(CellError::AlreadyTerminating)); + } + continue; + } + termination = Some(Termination { response_tx }); + yield_timer = None; + begin_termination( + &runtime_tx, + &runtime_control_tx, + &runtime_terminate_handle, + &cancellation_token, + ); + if runtime_closed { + finish_callbacks( + &cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::Cancel, + ).await; + send_termination_events( + observer.take(), + termination.take(), + CellEvent::Terminated { + content_items: std::mem::take(&mut content_items), + }, + ); + break; + } + } + } + } + _ = async { + if let Some(yield_timer) = yield_timer.as_mut() { + yield_timer.await; + } else { + std::future::pending::<()>().await; + } + } => { + yield_timer = None; + send_observer_event( + observer.take(), + CellEvent::Yielded { + content_items: 
std::mem::take(&mut content_items), + }, + ); + } + maybe_event = async { + if runtime_closed { + std::future::pending::>().await + } else { + event_rx.recv().await + } + }, if !yield_deadline_elapsed => { + let Some(event) = maybe_event else { + runtime_closed = true; + if termination.is_some() { + finish_callbacks( + &cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::Cancel, + ).await; + send_termination_events( + observer.take(), + termination.take(), + CellEvent::Terminated { + content_items: std::mem::take(&mut content_items), + }, + ); + break; + } + if completed_event.is_none() { + finish_callbacks( + &cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::DrainNotifications, + ).await; + let event = CellEvent::Completed { + content_items: std::mem::take(&mut content_items), + error_text: Some("exec runtime ended unexpectedly".to_string()), + }; + if send_or_buffer_completion( + event, + &mut observer, + &mut completed_event, + ) { + break; + } + } + continue; + }; + match event { + RuntimeEvent::Started => { + yield_timer = observer.as_ref().and_then(observer_timer); + } + RuntimeEvent::Pending => { + runtime_paused = true; + if matches!( + observer.as_ref().map(|observer| observer.mode), + Some(ObserveMode::PendingFrontier) + ) { + yield_timer = None; + send_observer_event( + observer.take(), + CellEvent::Pending { + content_items: std::mem::take(&mut content_items), + pending_tool_call_ids: std::mem::take( + &mut pending_tool_call_ids, + ), + }, + ); + } else { + pending_tool_call_ids.clear(); + let _ = runtime_control_tx.send(RuntimeControlCommand::Continue); + runtime_paused = false; + } + } + RuntimeEvent::ContentItem(item) => content_items.push(output_item(item)), + RuntimeEvent::YieldRequested => { + if matches!( + observer.as_ref().map(|observer| observer.mode), + Some(ObserveMode::YieldAfter(_)) + ) { + yield_timer = None; + send_observer_event( + observer.take(), + 
CellEvent::Yielded { + content_items: std::mem::take(&mut content_items), + }, + ); + } + } + RuntimeEvent::Notify { call_id, text } => { + spawn_notification( + &mut notification_tasks, + Arc::clone(&host), + call_id, + text, + cancellation_token.child_token(), + ); + } + RuntimeEvent::ToolCall { id, name, kind, input } => { + pending_tool_call_ids.push(id.clone()); + spawn_tool( + &mut tool_tasks, + Arc::clone(&host), + CellToolCall { + id, + name: CellToolName { + name: name.name, + namespace: name.namespace, + }, + kind: cell_tool_kind(kind), + input, + }, + runtime_tx.clone(), + cancellation_token.child_token(), + ); + } + RuntimeEvent::Result { stored_value_writes, error_text } => { + runtime_closed = true; + yield_timer = None; + if termination.is_some() { + finish_callbacks( + &cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::Cancel, + ).await; + send_termination_events( + observer.take(), + termination.take(), + CellEvent::Terminated { + content_items: std::mem::take(&mut content_items), + }, + ); + break; + } + finish_callbacks( + &cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::DrainNotifications, + ).await; + host.commit_stored_values(stored_value_writes).await; + let event = CellEvent::Completed { + content_items: std::mem::take(&mut content_items), + error_text, + }; + if send_or_buffer_completion( + event, + &mut observer, + &mut completed_event, + ) { + break; + } + } + } + } + task_result = notification_tasks.join_next(), if !notification_tasks.is_empty() => { + log_task_result(task_result, "notification"); + } + task_result = tool_tasks.join_next(), if !tool_tasks.is_empty() => { + log_task_result(task_result, "tool"); + } + } + } + begin_termination( + &runtime_tx, + &runtime_control_tx, + &runtime_terminate_handle, + &cancellation_token, + ); + finish_callbacks( + &cancellation_token, + &mut notification_tasks, + &mut tool_tasks, + CallbackCompletion::Cancel, + ) + 
.await; + host.closed().await; +} + +fn send_or_buffer_completion( + event: CellEvent, + observer: &mut Option, + completed_event: &mut Option, +) -> bool { + if observer.is_some() { + send_observer_event(observer.take(), event); + true + } else { + *completed_event = Some(event); + false + } +} + +fn send_observer_event(observer: Option, event: CellEvent) { + if let Some(observer) = observer { + let _ = observer.response_tx.send(Ok(event)); + } +} + +fn send_termination_events( + observer: Option, + termination: Option, + event: CellEvent, +) { + send_observer_event(observer, event.clone()); + if let Some(response_tx) = termination.and_then(|termination| termination.response_tx) { + let _ = response_tx.send(Ok(event)); + } +} + +fn observer_timer(observer: &Observer) -> Option>> { + match observer.mode { + ObserveMode::YieldAfter(duration) => Some(Box::pin(tokio::time::sleep(duration))), + ObserveMode::PendingFrontier => None, + } +} + +fn resume_for_observation( + mode: ObserveMode, + runtime_paused: &mut bool, + runtime_tx: &std::sync::mpsc::Sender, + runtime_control_tx: &std::sync::mpsc::Sender, +) { + if *runtime_paused { + let control = match mode { + ObserveMode::YieldAfter(_) => RuntimeControlCommand::Continue, + ObserveMode::PendingFrontier => RuntimeControlCommand::Resume, + }; + let _ = runtime_control_tx.send(control); + *runtime_paused = false; + } else if matches!(mode, ObserveMode::PendingFrontier) { + let _ = runtime_tx.send(RuntimeCommand::ObservePendingFrontier); + } +} + +fn begin_termination( + runtime_tx: &std::sync::mpsc::Sender, + runtime_control_tx: &std::sync::mpsc::Sender, + runtime_terminate_handle: &v8::IsolateHandle, + cancellation_token: &CancellationToken, +) { + cancellation_token.cancel(); + let _ = runtime_tx.send(RuntimeCommand::Terminate); + let _ = runtime_control_tx.send(RuntimeControlCommand::Terminate); + let _ = runtime_terminate_handle.terminate_execution(); +} + +#[cfg(test)] +#[path = "tests.rs"] +mod tests; diff --git 
a/codex-rs/code-mode/src/cell_actor/tests.rs b/codex-rs/code-mode/src/cell_actor/tests.rs new file mode 100644 index 000000000..218f1fcde --- /dev/null +++ b/codex-rs/code-mode/src/cell_actor/tests.rs @@ -0,0 +1,143 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use codex_code_mode_protocol::ExecuteRequest; +use codex_code_mode_protocol::FunctionCallOutputContentItem; +use pretty_assertions::assert_eq; +use serde_json::Value as JsonValue; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio_util::sync::CancellationToken; + +use super::*; + +struct TestHost; + +impl CellHost for TestHost { + async fn invoke_tool( + &self, + _invocation: CellToolCall, + _cancellation_token: CancellationToken, + ) -> Result { + Err("unexpected tool call".to_string()) + } + + async fn notify( + &self, + _call_id: String, + _text: String, + _cancellation_token: CancellationToken, + ) -> Result<(), String> { + Ok(()) + } + + async fn commit_stored_values(&self, _stored_value_writes: HashMap) {} + + async fn closed(&self) {} +} + +struct CellActorHarness { + event_tx: mpsc::UnboundedSender, + handle: CellHandle, + initial_event_rx: oneshot::Receiver>, + task: tokio::task::JoinHandle<()>, + _runtime_event_rx: mpsc::UnboundedReceiver, +} + +fn spawn_cell_actor_harness(initial_observe_mode: ObserveMode) -> CellActorHarness { + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let (initial_event_tx, initial_event_rx) = oneshot::channel(); + let (runtime_event_tx, runtime_event_rx) = mpsc::unbounded_channel(); + let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = spawn_runtime( + HashMap::new(), + ExecuteRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: Vec::new(), + source: "await new Promise(() => {});".to_string(), + yield_time_ms: None, + max_output_tokens: None, + }, + runtime_event_tx, + PendingRuntimeMode::PauseUntilResumed, + ) + .unwrap(); + let 
handle = CellHandle::new(command_tx, CancellationToken::new()); + let task = tokio::spawn(run_cell( + Arc::new(TestHost), + CellContext { + runtime_tx, + runtime_control_tx, + runtime_terminate_handle, + cancellation_token: CancellationToken::new(), + }, + event_rx, + command_rx, + Observer { + mode: initial_observe_mode, + response_tx: initial_event_tx, + }, + )); + + CellActorHarness { + event_tx, + handle, + initial_event_rx, + task, + _runtime_event_rx: runtime_event_rx, + } +} + +#[tokio::test] +async fn yield_timer_preempts_buffered_runtime_output() { + let harness = spawn_cell_actor_harness(ObserveMode::YieldAfter(Duration::ZERO)); + harness.event_tx.send(RuntimeEvent::Started).unwrap(); + harness + .event_tx + .send(RuntimeEvent::ContentItem( + FunctionCallOutputContentItem::InputText { + text: "queued output".to_string(), + }, + )) + .unwrap(); + + assert_eq!( + harness.initial_event_rx.await.unwrap(), + Ok(CellEvent::Yielded { + content_items: Vec::new(), + }) + ); + + let termination = harness.handle.terminate(); + drop(harness.event_tx); + assert_eq!( + termination.await, + Ok(CellEvent::Terminated { + content_items: vec![CellOutputItem::Text { + text: "queued output".to_string(), + }], + }) + ); + harness.task.await.unwrap(); +} + +#[tokio::test] +async fn queued_termination_preempts_unobserved_runtime_completion() { + let harness = spawn_cell_actor_harness(ObserveMode::YieldAfter(Duration::from_secs(60))); + harness + .event_tx + .send(RuntimeEvent::Result { + stored_value_writes: HashMap::new(), + error_text: None, + }) + .unwrap(); + let termination = harness.handle.terminate(); + + let terminated = Ok(CellEvent::Terminated { + content_items: Vec::new(), + }); + assert_eq!(termination.await, terminated.clone()); + assert_eq!(harness.initial_event_rx.await.unwrap(), terminated); + harness.task.await.unwrap(); +} diff --git a/codex-rs/code-mode/src/cell_actor/types.rs b/codex-rs/code-mode/src/cell_actor/types.rs new file mode 100644 index 
000000000..aef0b5211 --- /dev/null +++ b/codex-rs/code-mode/src/cell_actor/types.rs @@ -0,0 +1,203 @@ +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use serde_json::Value as JsonValue; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio_util::sync::CancellationToken; + +pub(crate) type CellEventFuture = + Pin> + Send + 'static>>; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum ObserveMode { + YieldAfter(Duration), + PendingFrontier, +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum CellEvent { + Yielded { + content_items: Vec, + }, + Pending { + content_items: Vec, + pending_tool_call_ids: Vec, + }, + Completed { + content_items: Vec, + error_text: Option, + }, + Terminated { + content_items: Vec, + }, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum CellError { + Busy, + AlreadyTerminating, + Closed, +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum CellOutputItem { + Text { + text: String, + }, + Image { + image_url: String, + detail: Option, + }, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum CellImageDetail { + Auto, + Low, + High, + Original, +} + +pub(crate) struct CellRequest { + pub(crate) tool_call_id: String, + pub(crate) enabled_tools: Vec, + pub(crate) source: String, +} + +pub(crate) struct CellToolDefinition { + pub(crate) name: String, + pub(crate) tool_name: CellToolName, + pub(crate) description: String, + pub(crate) kind: CellToolKind, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct CellToolName { + pub(crate) name: String, + pub(crate) namespace: Option, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum CellToolKind { + Function, + Freeform, +} + +pub(crate) struct CellToolCall { + pub(crate) id: String, + pub(crate) name: CellToolName, + pub(crate) kind: CellToolKind, + pub(crate) 
input: Option, +} + +/// Connects a cell actor to session-owned callbacks and lifecycle state. +/// +/// Implementations must honor callback cancellation and must not return from +/// `closed` until the session can no longer route requests to the cell. +pub(crate) trait CellHost: Send + Sync + 'static { + fn invoke_tool( + &self, + invocation: CellToolCall, + cancellation_token: CancellationToken, + ) -> impl Future> + Send; + + fn notify( + &self, + call_id: String, + text: String, + cancellation_token: CancellationToken, + ) -> impl Future> + Send; + + fn commit_stored_values( + &self, + stored_value_writes: HashMap, + ) -> impl Future + Send; + + fn closed(&self) -> impl Future + Send; +} + +#[derive(Clone)] +pub(crate) struct CellHandle { + command_tx: mpsc::UnboundedSender, + cancellation_token: CancellationToken, + termination_requested: Arc, +} + +impl CellHandle { + pub(super) fn new( + command_tx: mpsc::UnboundedSender, + cancellation_token: CancellationToken, + ) -> Self { + Self { + command_tx, + cancellation_token, + termination_requested: Arc::new(AtomicBool::new(false)), + } + } + + pub(crate) fn observe(&self, mode: ObserveMode) -> CellEventFuture { + let (response_tx, response_rx) = oneshot::channel(); + if self + .command_tx + .send(CellCommand::Observe { mode, response_tx }) + .is_err() + { + return closed_event(); + } + response_event(response_rx) + } + + pub(crate) fn terminate(&self) -> CellEventFuture { + if self + .termination_requested + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_err() + { + return Box::pin(async { Err(CellError::AlreadyTerminating) }); + } + let (response_tx, response_rx) = oneshot::channel(); + if self + .command_tx + .send(CellCommand::Terminate { + response_tx: Some(response_tx), + }) + .is_err() + { + self.termination_requested.store(false, Ordering::Relaxed); + return closed_event(); + } + response_event(response_rx) + } + + pub(crate) fn shutdown(&self) { + 
self.termination_requested.store(true, Ordering::Relaxed); + self.cancellation_token.cancel(); + let _ = self + .command_tx + .send(CellCommand::Terminate { response_tx: None }); + } +} + +pub(super) enum CellCommand { + Observe { + mode: ObserveMode, + response_tx: oneshot::Sender>, + }, + Terminate { + response_tx: Option>>, + }, +} + +fn response_event(response_rx: oneshot::Receiver>) -> CellEventFuture { + Box::pin(async move { response_rx.await.unwrap_or(Err(CellError::Closed)) }) +} + +fn closed_event() -> CellEventFuture { + Box::pin(async { Err(CellError::Closed) }) +} diff --git a/codex-rs/code-mode/src/lib.rs b/codex-rs/code-mode/src/lib.rs index 269925827..76f8aa9b9 100644 --- a/codex-rs/code-mode/src/lib.rs +++ b/codex-rs/code-mode/src/lib.rs @@ -1,3 +1,4 @@ +mod cell_actor; mod runtime; mod service; diff --git a/codex-rs/code-mode/src/runtime/mod.rs b/codex-rs/code-mode/src/runtime/mod.rs index 36ffd926c..bfc54f71b 100644 --- a/codex-rs/code-mode/src/runtime/mod.rs +++ b/codex-rs/code-mode/src/runtime/mod.rs @@ -25,17 +25,20 @@ pub(crate) enum RuntimeCommand { ToolResponse { id: String, result: JsonValue }, ToolError { id: String, error_text: String }, TimeoutFired { id: u64 }, + ObservePendingFrontier, Terminate, } #[derive(Clone, Copy, Debug, PartialEq)] pub(crate) enum PendingRuntimeMode { + #[cfg(test)] Continue, PauseUntilResumed, } #[derive(Debug)] pub(crate) enum RuntimeControlCommand { + Continue, Resume, Terminate, } @@ -245,6 +248,7 @@ fn run_runtime( return; } } + RuntimeCommand::ObservePendingFrontier => {} } scope.perform_microtask_checkpoint(); @@ -283,8 +287,10 @@ fn next_runtime_command( let _ = event_tx.send(RuntimeEvent::Pending); match pending_mode { + #[cfg(test)] PendingRuntimeMode::Continue => return command_rx.recv().ok(), PendingRuntimeMode::PauseUntilResumed => match control_rx.recv().ok()? 
{ + RuntimeControlCommand::Continue => return command_rx.recv().ok(), RuntimeControlCommand::Resume => continue, RuntimeControlCommand::Terminate => return Some(RuntimeCommand::Terminate), }, diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index db5cb55a6..4c8fa517a 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -12,10 +12,12 @@ use codex_code_mode_protocol::CodeModeSessionDelegate; use codex_code_mode_protocol::CodeModeSessionProvider; use codex_code_mode_protocol::CodeModeSessionProviderFuture; use codex_code_mode_protocol::CodeModeSessionResultFuture; +use codex_code_mode_protocol::CodeModeToolKind; use codex_code_mode_protocol::DEFAULT_EXEC_YIELD_TIME_MS; use codex_code_mode_protocol::ExecuteRequest; use codex_code_mode_protocol::ExecuteToPendingOutcome; use codex_code_mode_protocol::FunctionCallOutputContentItem; +use codex_code_mode_protocol::ImageDetail; use codex_code_mode_protocol::NotificationFuture; use codex_code_mode_protocol::RuntimeResponse; use codex_code_mode_protocol::StartedCell; @@ -26,17 +28,23 @@ use codex_code_mode_protocol::WaitToPendingOutcome; use codex_code_mode_protocol::WaitToPendingRequest; use serde_json::Value as JsonValue; use tokio::sync::Mutex; -use tokio::sync::mpsc; use tokio::sync::oneshot; -use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tracing::warn; -use crate::runtime::PendingRuntimeMode; -use crate::runtime::RuntimeCommand; -use crate::runtime::RuntimeControlCommand; -use crate::runtime::RuntimeEvent; -use crate::runtime::spawn_runtime; +use crate::cell_actor::CellActor; +use crate::cell_actor::CellError; +use crate::cell_actor::CellEvent; +use crate::cell_actor::CellEventFuture; +use crate::cell_actor::CellHandle; +use crate::cell_actor::CellHost; +use crate::cell_actor::CellImageDetail; +use crate::cell_actor::CellOutputItem; +use crate::cell_actor::CellRequest; +use crate::cell_actor::CellToolCall; +use 
crate::cell_actor::CellToolDefinition; +use crate::cell_actor::CellToolKind; +use crate::cell_actor::CellToolName; +use crate::cell_actor::ObserveMode; pub struct NoopCodeModeSessionDelegate; @@ -81,14 +89,6 @@ impl CodeModeSessionProvider for InProcessCodeModeSessionProvider { } } -#[derive(Clone)] -struct CellHandle { - control_tx: mpsc::UnboundedSender, - runtime_tx: std::sync::mpsc::Sender, - cancellation_token: CancellationToken, - termination_requested: Arc, -} - struct Inner { stored_values: Mutex>, cells: Mutex>, @@ -128,20 +128,24 @@ impl CodeModeService { } pub async fn execute(&self, request: ExecuteRequest) -> Result { - if self.inner.shutting_down.load(Ordering::Acquire) { - return Err("code mode session is shutting down".to_string()); - } - let initial_yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS); - let (response_tx, response_rx) = oneshot::channel(); + let yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS); let cell_id = self.allocate_cell_id(); - self.start_cell( - cell_id.clone(), - request, - CellResponseSender::Runtime(response_tx), - Some(initial_yield_time_ms), - PendingRuntimeMode::Continue, - ) - .await?; + let initial_event = self + .start_cell( + cell_id.clone(), + request, + ObserveMode::YieldAfter(Duration::from_millis(yield_time_ms)), + ) + .await?; + let response_cell_id = cell_id.clone(); + let (response_tx, response_rx) = oneshot::channel(); + tokio::spawn(async move { + let response = initial_event + .await + .map_err(|error| cell_error_text(&response_cell_id, error)) + .and_then(|event| runtime_response(&response_cell_id, event)); + let _ = response_tx.send(response); + }); Ok(StartedCell::from_result_receiver(cell_id, response_rx)) } @@ -150,75 +154,43 @@ impl CodeModeService { &self, request: ExecuteRequest, ) -> Result { - let (response_tx, response_rx) = oneshot::channel(); let cell_id = self.allocate_cell_id(); - self.start_cell( - cell_id, - request, - 
CellResponseSender::ExecuteToPending(response_tx), - /*initial_yield_time_ms*/ None, - PendingRuntimeMode::PauseUntilResumed, - ) - .await?; - - response_rx + let event = self + .start_cell(cell_id.clone(), request, ObserveMode::PendingFrontier) + .await? .await - .map_err(|_| "exec runtime ended unexpectedly".to_string())? + .map_err(|error| cell_error_text(&cell_id, error))?; + pending_outcome(&cell_id, event) } async fn start_cell( &self, cell_id: CellId, request: ExecuteRequest, - initial_response_tx: CellResponseSender, - initial_yield_time_ms: Option, - pending_mode: PendingRuntimeMode, - ) -> Result<(), String> { - let (event_tx, event_rx) = mpsc::unbounded_channel(); - let (control_tx, control_rx) = mpsc::unbounded_channel(); + initial_observe_mode: ObserveMode, + ) -> Result { let stored_values = self.inner.stored_values.lock().await.clone(); - let cancellation_token = CancellationToken::new(); - let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = { - let mut cells = self.inner.cells.lock().await; - if self.inner.shutting_down.load(Ordering::Acquire) { - return Err("code mode session is shutting down".to_string()); - } - if cells.contains_key(&cell_id) { - return Err(format!("exec cell {cell_id} already exists")); - } - - let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = - spawn_runtime(stored_values, request, event_tx, pending_mode)?; - - cells.insert( - cell_id.clone(), - CellHandle { - control_tx, - runtime_tx: runtime_tx.clone(), - cancellation_token: cancellation_token.clone(), - termination_requested: Arc::new(AtomicBool::new(false)), - }, - ); - (runtime_tx, runtime_control_tx, runtime_terminate_handle) - }; - - tokio::spawn(run_cell_control( - Arc::clone(&self.inner), - CellControlContext { - cell_id, - runtime_tx, - runtime_control_tx, - pending_mode, - runtime_terminate_handle, - cancellation_token, - }, - event_rx, - control_rx, - initial_response_tx, - initial_yield_time_ms, - )); - - Ok(()) + let host = 
Arc::new(ServiceCellHost { + cell_id: cell_id.clone(), + inner: Arc::clone(&self.inner), + }); + let mut cells = self.inner.cells.lock().await; + if self.inner.shutting_down.load(Ordering::Acquire) { + return Err("code mode session is shutting down".to_string()); + } + if cells.contains_key(&cell_id) { + return Err(format!("exec cell {cell_id} already exists")); + } + let (handle, initial_event, task) = CellActor::prepare( + cell_request(request), + stored_values, + host, + initial_observe_mode, + )?; + cells.insert(cell_id, handle); + drop(cells); + tokio::spawn(task); + Ok(initial_event) } pub async fn wait(&self, request: WaitRequest) -> Result { @@ -237,15 +209,12 @@ impl CodeModeService { let Some(handle) = handle else { return missing_wait(cell_id); }; - let (response_tx, response_rx) = oneshot::channel(); - let control_message = CellControlCommand::Poll { - yield_time_ms, - response_tx, - }; - if handle.control_tx.send(control_message).is_err() { - return missing_wait(cell_id); - } - wait_for_response(cell_id, response_rx) + wait_for_event( + cell_id, + handle.observe(ObserveMode::YieldAfter(Duration::from_millis( + yield_time_ms, + ))), + ) } pub async fn terminate(&self, cell_id: CellId) -> Result { @@ -253,27 +222,7 @@ impl CodeModeService { let Some(handle) = handle else { return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); }; - if handle - .termination_requested - .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) - .is_err() - { - return Err(already_terminating_error(&cell_id)); - } - let (response_tx, response_rx) = oneshot::channel(); - if handle - .control_tx - .send(CellControlCommand::Terminate { response_tx }) - .is_err() - { - handle.termination_requested.store(false, Ordering::Relaxed); - return Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))); - } - match response_rx.await { - Ok(Ok(response)) => Ok(WaitOutcome::LiveCell(response)), - Ok(Err(error_text)) => Err(error_text), - Err(_) => 
Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))), - } + wait_for_event(cell_id, handle.terminate()).await } pub async fn wait_to_pending( @@ -287,22 +236,14 @@ impl CodeModeService { cell_id, ))); }; - let (response_tx, response_rx) = oneshot::channel(); - if handle - .control_tx - .send(CellControlCommand::PollToPending { response_tx }) - .is_err() - { - return Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( - cell_id, - ))); - } - match response_rx.await { - Ok(Ok(response)) => Ok(WaitToPendingOutcome::LiveCell(response)), - Ok(Err(error_text)) => Err(error_text), - Err(_) => Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( + match handle.observe(ObserveMode::PendingFrontier).await { + Ok(event) => Ok(WaitToPendingOutcome::LiveCell(pending_outcome( + &cell_id, event, + )?)), + Err(CellError::Closed) => Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( cell_id, ))), + Err(error) => Err(cell_error_text(&cell_id, error)), } } @@ -317,12 +258,7 @@ impl CodeModeService { .cloned() .collect::>(); for handle in handles { - handle.cancellation_token.cancel(); - let (response_tx, _response_rx) = oneshot::channel(); - let _ = handle - .control_tx - .send(CellControlCommand::Terminate { response_tx }); - let _ = handle.runtime_tx.send(RuntimeCommand::Terminate); + handle.shutdown(); } while !self.inner.cells.lock().await.is_empty() { tokio::task::yield_now().await; @@ -342,12 +278,7 @@ impl Drop for CodeModeService { self.inner.shutting_down.store(true, Ordering::Release); if let Ok(cells) = self.inner.cells.try_lock() { for handle in cells.values() { - handle.cancellation_token.cancel(); - let (response_tx, _response_rx) = oneshot::channel(); - let _ = handle - .control_tx - .send(CellControlCommand::Terminate { response_tx }); - let _ = handle.runtime_tx.send(RuntimeCommand::Terminate); + handle.shutdown(); } } } @@ -378,36 +309,161 @@ impl CodeModeSession for CodeModeService { } } -enum CellControlCommand { - Poll { - 
yield_time_ms: u64, - response_tx: oneshot::Sender>, - }, - PollToPending { - response_tx: oneshot::Sender>, - }, - Terminate { - response_tx: oneshot::Sender>, - }, -} - -enum CellResponseSender { - Runtime(oneshot::Sender>), - ExecuteToPending(oneshot::Sender>), -} - -struct PendingResult { - content_items: Vec, - error_text: Option, -} - -struct CellControlContext { +struct ServiceCellHost { cell_id: CellId, - runtime_tx: std::sync::mpsc::Sender, - runtime_control_tx: std::sync::mpsc::Sender, - pending_mode: PendingRuntimeMode, - runtime_terminate_handle: v8::IsolateHandle, - cancellation_token: CancellationToken, + inner: Arc, +} + +impl CellHost for ServiceCellHost { + async fn invoke_tool( + &self, + invocation: CellToolCall, + cancellation_token: CancellationToken, + ) -> Result { + self.inner + .delegate + .invoke_tool( + CodeModeNestedToolCall { + cell_id: self.cell_id.clone(), + runtime_tool_call_id: invocation.id, + tool_name: codex_protocol::ToolName { + name: invocation.name.name, + namespace: invocation.name.namespace, + }, + tool_kind: match invocation.kind { + CellToolKind::Function => CodeModeToolKind::Function, + CellToolKind::Freeform => CodeModeToolKind::Freeform, + }, + input: invocation.input, + }, + cancellation_token, + ) + .await + } + + async fn notify( + &self, + call_id: String, + text: String, + cancellation_token: CancellationToken, + ) -> Result<(), String> { + self.inner + .delegate + .notify(call_id, self.cell_id.clone(), text, cancellation_token) + .await + } + + async fn commit_stored_values(&self, stored_value_writes: HashMap) { + self.inner + .stored_values + .lock() + .await + .extend(stored_value_writes); + } + + async fn closed(&self) { + self.inner.cells.lock().await.remove(&self.cell_id); + self.inner.delegate.cell_closed(&self.cell_id); + } +} + +fn cell_request(request: ExecuteRequest) -> CellRequest { + CellRequest { + tool_call_id: request.tool_call_id, + enabled_tools: request + .enabled_tools + .into_iter() + 
.map(|definition| CellToolDefinition { + name: definition.name, + tool_name: CellToolName { + name: definition.tool_name.name, + namespace: definition.tool_name.namespace, + }, + description: definition.description, + kind: match definition.kind { + CodeModeToolKind::Function => CellToolKind::Function, + CodeModeToolKind::Freeform => CellToolKind::Freeform, + }, + }) + .collect(), + source: request.source, + } +} + +fn wait_for_event( + cell_id: CellId, + event: CellEventFuture, +) -> CodeModeSessionResultFuture<'static, WaitOutcome> { + Box::pin(async move { + match event.await { + Ok(event) => Ok(WaitOutcome::LiveCell(runtime_response(&cell_id, event)?)), + Err(CellError::Closed) => Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))), + Err(error) => Err(cell_error_text(&cell_id, error)), + } + }) +} + +fn pending_outcome(cell_id: &CellId, event: CellEvent) -> Result { + match event { + CellEvent::Pending { + content_items, + pending_tool_call_ids, + } => Ok(ExecuteToPendingOutcome::Pending { + cell_id: cell_id.clone(), + content_items: content_items.into_iter().map(output_item).collect(), + pending_tool_call_ids, + }), + event => Ok(ExecuteToPendingOutcome::Completed(runtime_response( + cell_id, event, + )?)), + } +} + +fn runtime_response(cell_id: &CellId, event: CellEvent) -> Result { + match event { + CellEvent::Yielded { content_items } => Ok(RuntimeResponse::Yielded { + cell_id: cell_id.clone(), + content_items: content_items.into_iter().map(output_item).collect(), + }), + CellEvent::Completed { + content_items, + error_text, + } => Ok(RuntimeResponse::Result { + cell_id: cell_id.clone(), + content_items: content_items.into_iter().map(output_item).collect(), + error_text, + }), + CellEvent::Terminated { content_items } => Ok(RuntimeResponse::Terminated { + cell_id: cell_id.clone(), + content_items: content_items.into_iter().map(output_item).collect(), + }), + CellEvent::Pending { .. 
} => { + Err("cell returned a pending frontier unexpectedly".to_string()) + } + } +} + +fn output_item(item: CellOutputItem) -> FunctionCallOutputContentItem { + match item { + CellOutputItem::Text { text } => FunctionCallOutputContentItem::InputText { text }, + CellOutputItem::Image { image_url, detail } => FunctionCallOutputContentItem::InputImage { + image_url, + detail: detail.map(|detail| match detail { + CellImageDetail::Auto => ImageDetail::Auto, + CellImageDetail::Low => ImageDetail::Low, + CellImageDetail::High => ImageDetail::High, + CellImageDetail::Original => ImageDetail::Original, + }), + }, + } +} + +fn cell_error_text(cell_id: &CellId, error: CellError) -> String { + match error { + CellError::Busy => format!("exec cell {cell_id} already has an active observer"), + CellError::AlreadyTerminating => format!("exec cell {cell_id} is already terminating"), + CellError::Closed => format!("exec cell {cell_id} closed unexpectedly"), + } } fn missing_cell_response(cell_id: CellId) -> RuntimeResponse { @@ -422,1528 +478,9 @@ fn missing_wait(cell_id: CellId) -> CodeModeSessionResultFuture<'static, WaitOut Box::pin(async move { Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))) }) } -fn wait_for_response( - cell_id: CellId, - response_rx: oneshot::Receiver>, -) -> CodeModeSessionResultFuture<'static, WaitOutcome> { - Box::pin(async move { - match response_rx.await { - Ok(Ok(response)) => Ok(WaitOutcome::LiveCell(response)), - Ok(Err(error_text)) => Err(error_text), - Err(_) => Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))), - } - }) -} - -fn busy_observer_error(cell_id: &CellId) -> String { - format!("exec cell {cell_id} already has an active observer") -} - -fn already_terminating_error(cell_id: &CellId) -> String { - format!("exec cell {cell_id} is already terminating") -} - -fn pending_result_response(cell_id: &CellId, result: PendingResult) -> RuntimeResponse { - RuntimeResponse::Result { - cell_id: cell_id.clone(), - content_items: 
result.content_items, - error_text: result.error_text, - } -} - -fn send_terminal_response(response_tx: CellResponseSender, response: RuntimeResponse) { - match response_tx { - CellResponseSender::Runtime(response_tx) => { - let _ = response_tx.send(Ok(response)); - } - CellResponseSender::ExecuteToPending(response_tx) => { - let _ = response_tx.send(Ok(ExecuteToPendingOutcome::Completed(response))); - } - } -} - -fn send_termination_responses( - response_tx: Option, - termination_response_tx: Option>>, - response: RuntimeResponse, -) { - if let Some(response_tx) = response_tx { - send_terminal_response(response_tx, response.clone()); - } - if let Some(termination_response_tx) = termination_response_tx { - let _ = termination_response_tx.send(Ok(response)); - } -} - -fn send_or_buffer_result( - cell_id: &CellId, - result: PendingResult, - response_tx: &mut Option, - pending_result: &mut Option, -) -> bool { - if let Some(response_tx) = response_tx.take() { - let response = pending_result_response(cell_id, result); - send_terminal_response(response_tx, response); - return true; - } - - *pending_result = Some(result); - false -} - -fn send_yield_response( - cell_id: &CellId, - content_items: &mut Vec, - response_tx: &mut Option, -) { - let Some(current_response_tx) = response_tx.take() else { - return; - }; - match current_response_tx { - CellResponseSender::Runtime(response_tx) => { - let _ = response_tx.send(Ok(RuntimeResponse::Yielded { - cell_id: cell_id.clone(), - content_items: std::mem::take(content_items), - })); - } - CellResponseSender::ExecuteToPending(execute_to_pending_tx) => { - *response_tx = Some(CellResponseSender::ExecuteToPending(execute_to_pending_tx)); - } - } -} - -async fn run_cell_control( - inner: Arc, - context: CellControlContext, - mut event_rx: mpsc::UnboundedReceiver, - mut control_rx: mpsc::UnboundedReceiver, - initial_response_tx: CellResponseSender, - initial_yield_time_ms: Option, -) { - let CellControlContext { - cell_id, - 
runtime_tx, - runtime_control_tx, - pending_mode, - runtime_terminate_handle, - cancellation_token, - } = context; - let mut content_items = Vec::new(); - let mut pending_tool_call_ids = Vec::new(); - let mut pending_result: Option = None; - let mut response_tx = Some(initial_response_tx); - let mut termination_response_tx = None; - let mut termination_requested = false; - let mut runtime_closed = false; - let mut yield_timer: Option>> = None; - let mut notification_tasks = JoinSet::new(); - let mut tool_tasks = JoinSet::new(); - - loop { - let yield_deadline_elapsed = yield_timer - .as_ref() - .is_some_and(|yield_timer| yield_timer.deadline() <= tokio::time::Instant::now()); - tokio::select! { - biased; - maybe_command = control_rx.recv() => { - let Some(command) = maybe_command else { - break; - }; - match command { - CellControlCommand::Poll { - yield_time_ms, - response_tx: next_response_tx, - } => { - if let Some(result) = pending_result.take() { - let _ = next_response_tx.send(Ok(pending_result_response(&cell_id, result))); - break; - } - if response_tx.is_some() || termination_response_tx.is_some() { - let _ = next_response_tx.send(Err(busy_observer_error(&cell_id))); - continue; - } - response_tx = Some(CellResponseSender::Runtime(next_response_tx)); - yield_timer = Some(Box::pin(tokio::time::sleep(Duration::from_millis(yield_time_ms)))); - resume_paused_runtime(&runtime_control_tx, pending_mode); - } - CellControlCommand::PollToPending { - response_tx: next_response_tx, - } => { - if let Some(result) = pending_result.take() { - let response = pending_result_response(&cell_id, result); - let _ = next_response_tx - .send(Ok(ExecuteToPendingOutcome::Completed(response))); - break; - } - if response_tx.is_some() || termination_response_tx.is_some() { - let _ = next_response_tx.send(Err(busy_observer_error(&cell_id))); - continue; - } - response_tx = - Some(CellResponseSender::ExecuteToPending(next_response_tx)); - yield_timer = None; - 
resume_paused_runtime(&runtime_control_tx, pending_mode); - } - CellControlCommand::Terminate { response_tx: next_response_tx } => { - if let Some(result) = pending_result.take() { - let _ = next_response_tx.send(Ok(pending_result_response(&cell_id, result))); - break; - } - - if termination_response_tx.is_some() { - let _ = next_response_tx.send(Err(already_terminating_error(&cell_id))); - continue; - } - - termination_response_tx = Some(next_response_tx); - termination_requested = true; - cancellation_token.cancel(); - yield_timer = None; - let _ = runtime_tx.send(RuntimeCommand::Terminate); - terminate_paused_runtime(&runtime_control_tx, pending_mode); - let _ = runtime_terminate_handle.terminate_execution(); - if runtime_closed { - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::Cancel, - ).await; - let response = RuntimeResponse::Terminated { - cell_id: cell_id.clone(), - content_items: std::mem::take(&mut content_items), - }; - send_termination_responses( - response_tx.take(), - termination_response_tx.take(), - response, - ); - break; - } else { - continue; - } - } - } - } - _ = async { - if let Some(yield_timer) = yield_timer.as_mut() { - yield_timer.await; - } else { - std::future::pending::<()>().await; - } - } => { - yield_timer = None; - send_yield_response(&cell_id, &mut content_items, &mut response_tx); - } - maybe_event = async { - if runtime_closed { - std::future::pending::>().await - } else { - event_rx.recv().await - } - }, if !yield_deadline_elapsed => { - let Some(event) = maybe_event else { - runtime_closed = true; - if termination_requested { - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::Cancel, - ).await; - let response = RuntimeResponse::Terminated { - cell_id: cell_id.clone(), - content_items: std::mem::take(&mut content_items), - }; - send_termination_responses( - response_tx.take(), - 
termination_response_tx.take(), - response, - ); - break; - } - if pending_result.is_none() { - let result = PendingResult { - content_items: std::mem::take(&mut content_items), - error_text: Some("exec runtime ended unexpectedly".to_string()), - }; - if send_or_buffer_result( - &cell_id, - result, - &mut response_tx, - &mut pending_result, - ) { - break; - } - } - continue; - }; - match event { - RuntimeEvent::Started => { - yield_timer = initial_yield_time_ms.map(|initial_yield_time_ms| { - Box::pin(tokio::time::sleep(Duration::from_millis(initial_yield_time_ms))) - }); - } - RuntimeEvent::Pending => { - if let Some(current_response_tx) = response_tx.take() { - match current_response_tx { - CellResponseSender::Runtime(runtime_response_tx) => { - response_tx = - Some(CellResponseSender::Runtime(runtime_response_tx)); - } - CellResponseSender::ExecuteToPending(response_tx) => { - let _ = response_tx.send(Ok(ExecuteToPendingOutcome::Pending { - cell_id: cell_id.clone(), - content_items: std::mem::take(&mut content_items), - pending_tool_call_ids: std::mem::take( - &mut pending_tool_call_ids, - ), - })); - } - } - } - } - RuntimeEvent::ContentItem(item) => { - content_items.push(item); - } - RuntimeEvent::YieldRequested => { - yield_timer = None; - send_yield_response(&cell_id, &mut content_items, &mut response_tx); - } - RuntimeEvent::Notify { call_id, text } => { - let delegate = Arc::clone(&inner.delegate); - let cell_id = cell_id.clone(); - let cancellation_token = cancellation_token.child_token(); - notification_tasks.spawn(async move { - if let Err(err) = delegate - .notify(call_id, cell_id.clone(), text, cancellation_token) - .await - { - warn!( - "failed to deliver code mode notification for cell {cell_id}: {err}" - ); - } - }); - } - RuntimeEvent::ToolCall { - id, - name, - kind, - input, - } => { - if pending_mode == PendingRuntimeMode::PauseUntilResumed { - pending_tool_call_ids.push(id.clone()); - } - let tool_call = CodeModeNestedToolCall { - cell_id: 
cell_id.clone(), - runtime_tool_call_id: id.clone(), - tool_name: name, - tool_kind: kind, - input, - }; - let delegate = Arc::clone(&inner.delegate); - let runtime_tx = runtime_tx.clone(); - let cancellation_token = cancellation_token.child_token(); - tool_tasks.spawn(async move { - let response = delegate.invoke_tool(tool_call, cancellation_token).await; - let command = match response { - Ok(result) => RuntimeCommand::ToolResponse { id, result }, - Err(error_text) => RuntimeCommand::ToolError { id, error_text }, - }; - let _ = runtime_tx.send(command); - }); - } - RuntimeEvent::Result { - stored_value_writes, - error_text, - } => { - yield_timer = None; - if termination_requested { - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::Cancel, - ).await; - let response = RuntimeResponse::Terminated { - cell_id: cell_id.clone(), - content_items: std::mem::take(&mut content_items), - }; - send_termination_responses( - response_tx.take(), - termination_response_tx.take(), - response, - ); - break; - } - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::DrainNotifications, - ).await; - inner - .stored_values - .lock() - .await - .extend(stored_value_writes); - let result = PendingResult { - content_items: std::mem::take(&mut content_items), - error_text, - }; - if send_or_buffer_result( - &cell_id, - result, - &mut response_tx, - &mut pending_result, - ) { - break; - } - } - } - } - task_result = notification_tasks.join_next(), if !notification_tasks.is_empty() => { - if let Some(Err(err)) = task_result - && !err.is_cancelled() - { - warn!("code mode notification task failed: {err}"); - } - } - task_result = tool_tasks.join_next(), if !tool_tasks.is_empty() => { - if let Some(Err(err)) = task_result - && !err.is_cancelled() - { - warn!("code mode tool task failed: {err}"); - } - } - } - } - - let _ = runtime_tx.send(RuntimeCommand::Terminate); - 
cancellation_token.cancel(); - finish_callbacks( - &cancellation_token, - &mut notification_tasks, - &mut tool_tasks, - CallbackCompletion::Cancel, - ) - .await; - terminate_paused_runtime(&runtime_control_tx, pending_mode); - inner.cells.lock().await.remove(&cell_id); - inner.delegate.cell_closed(&cell_id); -} - -#[derive(Clone, Copy)] -enum CallbackCompletion { - DrainNotifications, - Cancel, -} - -async fn finish_callbacks( - cancellation_token: &CancellationToken, - notification_tasks: &mut JoinSet<()>, - tool_tasks: &mut JoinSet<()>, - completion: CallbackCompletion, -) { - if matches!(completion, CallbackCompletion::Cancel) { - cancellation_token.cancel(); - } - drain_tasks(notification_tasks, "notification").await; - cancellation_token.cancel(); - drain_tasks(tool_tasks, "tool").await; -} - -async fn drain_tasks(tasks: &mut JoinSet<()>, description: &str) { - while let Some(result) = tasks.join_next().await { - if let Err(err) = result - && !err.is_cancelled() - { - warn!("code mode {description} task failed: {err}"); - } - } -} - -fn resume_paused_runtime( - runtime_control_tx: &std::sync::mpsc::Sender, - pending_mode: PendingRuntimeMode, -) { - if pending_mode == PendingRuntimeMode::PauseUntilResumed { - let _ = runtime_control_tx.send(RuntimeControlCommand::Resume); - } -} - -fn terminate_paused_runtime( - runtime_control_tx: &std::sync::mpsc::Sender, - pending_mode: PendingRuntimeMode, -) { - if pending_mode == PendingRuntimeMode::PauseUntilResumed { - let _ = runtime_control_tx.send(RuntimeControlCommand::Terminate); - } -} - #[cfg(test)] -mod tests { - use std::collections::HashMap; - use std::sync::Arc; - use std::sync::atomic::AtomicU64; - use std::sync::atomic::Ordering; - use std::time::Duration; - - use codex_protocol::ToolName; - use pretty_assertions::assert_eq; - use tokio::sync::Mutex; - use tokio::sync::mpsc; - use tokio::sync::oneshot; - - use super::CellControlCommand; - use super::CellControlContext; - use super::CellId; - use 
super::CellResponseSender; - use super::CodeModeService; - use super::Inner; - use super::NoopCodeModeSessionDelegate; - use super::PendingRuntimeMode; - use super::RuntimeCommand; - use super::RuntimeResponse; - use super::WaitOutcome; - use super::WaitRequest; - use super::WaitToPendingOutcome; - use super::WaitToPendingRequest; - use super::run_cell_control; - use crate::CodeModeToolKind; - use crate::ExecuteRequest; - use crate::ExecuteToPendingOutcome; - use crate::FunctionCallOutputContentItem; - use crate::ToolDefinition; - use crate::runtime::RuntimeEvent; - use crate::runtime::spawn_runtime; - - fn execute_request(source: &str) -> ExecuteRequest { - ExecuteRequest { - tool_call_id: "call_1".to_string(), - enabled_tools: Vec::new(), - source: source.to_string(), - yield_time_ms: Some(1), - max_output_tokens: None, - } - } - - fn cell_id(value: &str) -> CellId { - CellId::new(value.to_string()) - } - - async fn execute(service: &CodeModeService, request: ExecuteRequest) -> RuntimeResponse { - service - .execute(request) - .await - .unwrap() - .initial_response() - .await - .unwrap() - } - - fn test_inner() -> Arc { - Arc::new(Inner { - stored_values: Mutex::new(HashMap::new()), - cells: Mutex::new(HashMap::new()), - delegate: Arc::new(NoopCodeModeSessionDelegate), - shutting_down: std::sync::atomic::AtomicBool::new(false), - next_cell_id: AtomicU64::new(1), - }) - } - - #[tokio::test] - async fn synchronous_exit_returns_successfully() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#"text("before"); exit(); text("after");"#.to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "before".to_string(), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn 
stored_values_are_shared_between_cells_but_not_sessions() { - let first_session = CodeModeService::new(); - let second_session = CodeModeService::new(); - - let write_response = execute( - &first_session, - ExecuteRequest { - source: r#"store("key", "visible");"#.to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - let same_session = execute( - &first_session, - ExecuteRequest { - source: r#"text(String(load("key")));"#.to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - let other_session = execute( - &second_session, - ExecuteRequest { - source: r#"text(String(load("key")));"#.to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - write_response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: Vec::new(), - error_text: None, - } - ); - assert_eq!( - same_session, - RuntimeResponse::Result { - cell_id: cell_id("2"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "visible".to_string(), - }], - error_text: None, - } - ); - assert_eq!( - other_session, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "undefined".to_string(), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn shutdown_interrupts_cpu_bound_cells() { - let service = CodeModeService::new(); - - let cell = service - .execute(ExecuteRequest { - source: "while (true) {}".to_string(), - ..execute_request("") - }) - .await - .unwrap(); - assert_eq!( - cell.initial_response().await.unwrap(), - RuntimeResponse::Yielded { - cell_id: cell_id("1"), - content_items: Vec::new(), - } - ); - - tokio::time::timeout(Duration::from_secs(1), service.shutdown()) - .await - .unwrap() - .unwrap(); - } - - #[tokio::test] - async fn start_cell_rejects_new_cell_after_shutdown_begins() { - let service = CodeModeService::new(); - service.inner.shutting_down.store(true, Ordering::Release); - 
let (response_tx, _response_rx) = oneshot::channel(); - - let error = service - .start_cell( - cell_id("late-cell"), - execute_request(""), - CellResponseSender::Runtime(response_tx), - Some(/*initial_yield_time_ms*/ 1), - PendingRuntimeMode::Continue, - ) - .await - .unwrap_err(); - - assert_eq!(error, "code mode session is shutting down".to_string()); - assert!(service.inner.cells.lock().await.is_empty()); - } - - #[tokio::test] - async fn execute_to_pending_returns_completed_for_synchronous_results() { - let service = CodeModeService::new(); - - let response = service - .execute_to_pending(ExecuteRequest { - source: r#"text("done");"#.to_string(), - yield_time_ms: Some(60_000), - ..execute_request("") - }) - .await - .unwrap(); - - assert_eq!( - response, - ExecuteToPendingOutcome::Completed(RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "done".to_string(), - }], - error_text: None, - }) - ); - } - - #[tokio::test] - async fn execute_to_pending_returns_once_the_runtime_is_quiescent() { - let service = CodeModeService::new(); - - let response = tokio::time::timeout( - Duration::from_secs(1), - service.execute_to_pending(ExecuteRequest { - source: r#"text("before"); await new Promise(() => {});"#.to_string(), - yield_time_ms: Some(60_000), - ..execute_request("") - }), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - response, - ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "before".to_string(), - }], - pending_tool_call_ids: Vec::new(), - } - ); - - let termination = service.terminate(cell_id("1")).await.unwrap(); - - assert_eq!( - termination, - WaitOutcome::LiveCell(RuntimeResponse::Terminated { - cell_id: cell_id("1"), - content_items: Vec::new(), - }) - ); - } - - #[tokio::test] - async fn execute_to_pending_identifies_tool_calls_in_paused_frontier() { - let service = CodeModeService::new(); 
- - let response = service - .execute_to_pending(ExecuteRequest { - enabled_tools: vec![ToolDefinition { - name: "echo".to_string(), - tool_name: ToolName::plain("echo"), - description: String::new(), - kind: CodeModeToolKind::Function, - input_schema: None, - output_schema: None, - }], - source: r#" -await Promise.all([ - tools.echo({ value: "first" }), - tools.echo({ value: "second" }), -]); -"# - .to_string(), - yield_time_ms: Some(60_000), - ..execute_request("") - }) - .await - .unwrap(); - - assert_eq!( - response, - ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: Vec::new(), - pending_tool_call_ids: vec!["tool-1".to_string(), "tool-2".to_string()], - } - ); - - let termination = service.terminate(cell_id("1")).await.unwrap(); - - assert_eq!( - termination, - WaitOutcome::LiveCell(RuntimeResponse::Terminated { - cell_id: cell_id("1"), - content_items: Vec::new(), - }) - ); - } - - #[tokio::test] - async fn execute_to_pending_excludes_delayed_timeout_tool_calls_until_wait() { - let service = CodeModeService::new(); - - let initial_response = service - .execute_to_pending(ExecuteRequest { - enabled_tools: vec![ToolDefinition { - name: "echo".to_string(), - tool_name: ToolName::plain("echo"), - description: String::new(), - kind: CodeModeToolKind::Function, - input_schema: None, - output_schema: None, - }], - source: r#" -setTimeout(() => { - tools.echo({ value: "delayed" }); -}, 1000); -await Promise.all([ - tools.echo({ value: "second" }), - tools.echo({ value: "third" }), -]); -"# - .to_string(), - yield_time_ms: Some(60_000), - ..execute_request("") - }) - .await - .unwrap(); - - assert_eq!( - initial_response, - ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: Vec::new(), - pending_tool_call_ids: vec!["tool-1".to_string(), "tool-2".to_string()], - } - ); - - let runtime_tx = service - .inner - .cells - .lock() - .await - .get(&cell_id("1")) - .unwrap() - .runtime_tx - .clone(); - runtime_tx - 
.send(RuntimeCommand::TimeoutFired { id: 1 }) - .unwrap(); - - let resumed_response = tokio::time::timeout( - Duration::from_secs(1), - service.wait_to_pending(WaitToPendingRequest { - cell_id: cell_id("1"), - }), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - resumed_response, - WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: Vec::new(), - pending_tool_call_ids: vec!["tool-3".to_string()], - }) - ); - - let termination = service.terminate(cell_id("1")).await.unwrap(); - - assert_eq!( - termination, - WaitOutcome::LiveCell(RuntimeResponse::Terminated { - cell_id: cell_id("1"), - content_items: Vec::new(), - }) - ); - } - - #[tokio::test] - async fn wait_to_pending_returns_after_resumed_runtime_becomes_quiescent_again() { - let service = CodeModeService::new(); - - let initial_response = service - .execute_to_pending(ExecuteRequest { - source: r#" -await new Promise((resolve) => setTimeout(resolve, 60_000)); -text("after"); -await new Promise(() => {}); -"# - .to_string(), - yield_time_ms: Some(60_000), - ..execute_request("") - }) - .await - .unwrap(); - - assert_eq!( - initial_response, - ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: Vec::new(), - pending_tool_call_ids: Vec::new(), - } - ); - - let runtime_tx = service - .inner - .cells - .lock() - .await - .get(&cell_id("1")) - .unwrap() - .runtime_tx - .clone(); - runtime_tx - .send(RuntimeCommand::TimeoutFired { id: 1 }) - .unwrap(); - - let resumed_response = tokio::time::timeout( - Duration::from_secs(1), - service.wait_to_pending(WaitToPendingRequest { - cell_id: cell_id("1"), - }), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - resumed_response, - WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "after".to_string(), - }], - pending_tool_call_ids: Vec::new(), - }) - ); - - let termination 
= service.terminate(cell_id("1")).await.unwrap(); - - assert_eq!( - termination, - WaitOutcome::LiveCell(RuntimeResponse::Terminated { - cell_id: cell_id("1"), - content_items: Vec::new(), - }) - ); - } - - #[tokio::test] - async fn wait_to_pending_returns_completed_after_resumed_runtime_finishes() { - let service = CodeModeService::new(); - - let initial_response = service - .execute_to_pending(ExecuteRequest { - source: r#" -await new Promise((resolve) => setTimeout(resolve, 60_000)); -text("done"); -"# - .to_string(), - yield_time_ms: Some(60_000), - ..execute_request("") - }) - .await - .unwrap(); - - assert_eq!( - initial_response, - ExecuteToPendingOutcome::Pending { - cell_id: cell_id("1"), - content_items: Vec::new(), - pending_tool_call_ids: Vec::new(), - } - ); - - let runtime_tx = service - .inner - .cells - .lock() - .await - .get(&cell_id("1")) - .unwrap() - .runtime_tx - .clone(); - runtime_tx - .send(RuntimeCommand::TimeoutFired { id: 1 }) - .unwrap(); - - let resumed_response = tokio::time::timeout( - Duration::from_secs(1), - service.wait_to_pending(WaitToPendingRequest { - cell_id: cell_id("1"), - }), - ) - .await - .unwrap() - .unwrap(); - - assert_eq!( - resumed_response, - WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Completed( - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "done".to_string(), - }], - error_text: None, - } - )) - ); - } - - #[tokio::test] - async fn v8_console_is_not_exposed_on_global_this() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#"text(String(Object.hasOwn(globalThis, "console")));"#.to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "false".to_string(), - }], - error_text: None, - } - ); 
- } - - #[tokio::test] - async fn date_locale_string_formats_with_icu_data() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -const value = new Date("2025-01-02T03:04:05Z") - .toLocaleString("fr-FR", { - weekday: "long", - month: "long", - day: "numeric", - hour: "2-digit", - minute: "2-digit", - second: "2-digit", - hour12: false, - timeZone: "UTC", - }); -text(value); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "jeudi 2 janvier \u{e0} 03:04:05".to_string(), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn intl_date_time_format_formats_with_icu_data() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -const formatter = new Intl.DateTimeFormat("fr-FR", { - weekday: "long", - month: "long", - day: "numeric", - hour: "2-digit", - minute: "2-digit", - second: "2-digit", - hour12: false, - timeZone: "UTC", -}); -text(formatter.format(new Date("2025-01-02T03:04:05Z"))); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "jeudi 2 janvier \u{e0} 03:04:05".to_string(), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn output_helpers_return_undefined() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -const returnsUndefined = [ - text("first"), - image("data:image/png;base64,AAA"), - notify("ping"), -].map((value) => value === undefined); -text(JSON.stringify(returnsUndefined)); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; 
- - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![ - FunctionCallOutputContentItem::InputText { - text: "first".to_string(), - }, - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: Some(crate::DEFAULT_IMAGE_DETAIL), - }, - FunctionCallOutputContentItem::InputText { - text: "[true,true,true]".to_string(), - }, - ], - error_text: None, - } - ); - } - - #[tokio::test] - async fn image_helper_accepts_raw_mcp_image_block_with_original_detail() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -image({ - type: "image", - data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - mimeType: "image/png", - _meta: { "codex/imageDetail": "original" }, -}); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==".to_string(), - detail: Some(crate::ImageDetail::Original), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn generated_image_helper_appends_image_and_output_hint() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -generatedImage({ - image_url: "data:image/png;base64,AAA", - output_hint: "generated image save hint", -}); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![ - FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: Some(crate::DEFAULT_IMAGE_DETAIL), - }, - 
FunctionCallOutputContentItem::InputText { - text: "generated image save hint".to_string(), - }, - ], - error_text: None, - } - ); - } - - #[tokio::test] - async fn image_helper_second_arg_overrides_explicit_object_detail() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -image( - { - image_url: "data:image/png;base64,AAA", - detail: "high", - }, - "original", -); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: Some(crate::ImageDetail::Original), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn image_helper_second_arg_overrides_raw_mcp_image_detail() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -image( - { - type: "image", - data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - mimeType: "image/png", - _meta: { "codex/imageDetail": "original" }, - }, - "high", -); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==".to_string(), - detail: Some(crate::ImageDetail::High), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn image_helper_accepts_low_detail() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -image({ - image_url: "data:image/png;base64,AAA", - detail: "low", -}); -"# - .to_string(), - yield_time_ms: None, - 
..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputImage { - image_url: "data:image/png;base64,AAA".to_string(), - detail: Some(crate::ImageDetail::Low), - }], - error_text: None, - } - ); - } - - #[tokio::test] - async fn image_helpers_reject_remote_urls() { - for image_url in [ - "http://example.com/image.jpg", - "https://example.com/image.jpg", - ] { - for source in [ - format!("image({image_url:?});"), - format!("generatedImage({{ image_url: {image_url:?} }});"), - ] { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source, - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: Vec::new(), - error_text: Some( - "Tool call failed: remote image URLs are not supported in tool outputs. Pass a base64 data URI instead".to_string(), - ), - } - ); - } - } - } - - #[tokio::test] - async fn image_helper_rejects_unsupported_detail() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -image({ - image_url: "data:image/png;base64,AAA", - detail: "medium", -}); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: Vec::new(), - error_text: Some( - "image detail must be one of: auto, low, high, original".to_string() - ), - } - ); - } - - #[tokio::test] - async fn image_helper_rejects_raw_mcp_result_container() { - let service = CodeModeService::new(); - - let response = execute( - &service, - ExecuteRequest { - source: r#" -image({ - content: [ - { - type: "image", - data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", - mimeType: "image/png", - 
_meta: { "codex/imageDetail": "original" }, - }, - ], - isError: false, -}); -"# - .to_string(), - yield_time_ms: None, - ..execute_request("") - }, - ) - .await; - - assert_eq!( - response, - RuntimeResponse::Result { - cell_id: cell_id("1"), - content_items: Vec::new(), - error_text: Some( - "image expects a non-empty image URL string, an object with image_url and optional detail, or a raw MCP image block".to_string(), - ), - } - ); - } - - #[tokio::test] - async fn wait_reports_missing_cell_separately_from_runtime_results() { - let service = CodeModeService::new(); - - let response = service - .wait(WaitRequest { - cell_id: cell_id("missing"), - yield_time_ms: 1, - }) - .await - .unwrap(); - - assert_eq!( - response, - WaitOutcome::MissingCell(RuntimeResponse::Result { - cell_id: cell_id("missing"), - content_items: Vec::new(), - error_text: Some("exec cell missing not found".to_string()), - }) - ); - } - - #[tokio::test] - async fn terminate_waits_for_runtime_shutdown_before_responding() { - let inner = test_inner(); - let (event_tx, event_rx) = mpsc::unbounded_channel(); - let (control_tx, control_rx) = mpsc::unbounded_channel(); - let (initial_response_tx, initial_response_rx) = oneshot::channel(); - let (runtime_event_tx, _runtime_event_rx) = mpsc::unbounded_channel(); - let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = spawn_runtime( - HashMap::new(), - ExecuteRequest { - source: "await new Promise(() => {})".to_string(), - yield_time_ms: None, - ..execute_request("") - }, - runtime_event_tx, - PendingRuntimeMode::Continue, - ) - .unwrap(); - - tokio::spawn(run_cell_control( - inner, - CellControlContext { - cell_id: cell_id("cell-1"), - runtime_tx: runtime_tx.clone(), - runtime_control_tx, - pending_mode: PendingRuntimeMode::Continue, - runtime_terminate_handle, - cancellation_token: tokio_util::sync::CancellationToken::new(), - }, - event_rx, - control_rx, - CellResponseSender::Runtime(initial_response_tx), - Some(/*initial_yield_time_ms*/ 
60_000), - )); - - event_tx.send(RuntimeEvent::Started).unwrap(); - event_tx.send(RuntimeEvent::YieldRequested).unwrap(); - assert_eq!( - initial_response_rx.await.unwrap(), - Ok(RuntimeResponse::Yielded { - cell_id: cell_id("cell-1"), - content_items: Vec::new(), - }) - ); - - let (terminate_response_tx, terminate_response_rx) = oneshot::channel(); - control_tx - .send(CellControlCommand::Terminate { - response_tx: terminate_response_tx, - }) - .unwrap(); - let terminate_response = async { terminate_response_rx.await.unwrap() }; - tokio::pin!(terminate_response); - assert!( - tokio::time::timeout(Duration::from_millis(100), terminate_response.as_mut()) - .await - .is_err() - ); - - drop(event_tx); - - assert_eq!( - terminate_response.await, - Ok(RuntimeResponse::Terminated { - cell_id: cell_id("cell-1"), - content_items: Vec::new(), - }) - ); - - let _ = runtime_tx.send(RuntimeCommand::Terminate); - } -} +#[path = "service_tests.rs"] +mod tests; #[cfg(test)] #[path = "service_contract_tests.rs"] diff --git a/codex-rs/code-mode/src/service_contract_tests.rs b/codex-rs/code-mode/src/service_contract_tests.rs index 752c6cc6c..3d814e98c 100644 --- a/codex-rs/code-mode/src/service_contract_tests.rs +++ b/codex-rs/code-mode/src/service_contract_tests.rs @@ -1,7 +1,5 @@ -use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use std::time::Duration; @@ -9,7 +7,6 @@ use codex_protocol::ToolName; use pretty_assertions::assert_eq; use tokio::sync::Notify; use tokio::sync::mpsc; -use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; use super::*; @@ -29,6 +26,7 @@ struct BlockingDelegate { events_tx: mpsc::UnboundedSender, notification_finished: AtomicBool, tool_finished: AtomicBool, + tool_release: Notify, } struct HeldNotificationDelegate { @@ -88,61 +86,6 @@ impl CodeModeSessionDelegate for HeldNotificationDelegate { } } -struct CellControlHarness { - event_tx: 
mpsc::UnboundedSender, - control_tx: mpsc::UnboundedSender, - initial_response_rx: oneshot::Receiver>, - task: tokio::task::JoinHandle<()>, - _runtime_event_rx: mpsc::UnboundedReceiver, -} - -fn spawn_cell_control_harness( - initial_yield_time_ms: Option, - delegate: Arc, -) -> CellControlHarness { - let (event_tx, event_rx) = mpsc::unbounded_channel(); - let (control_tx, control_rx) = mpsc::unbounded_channel(); - let (initial_response_tx, initial_response_rx) = oneshot::channel(); - let (runtime_event_tx, runtime_event_rx) = mpsc::unbounded_channel(); - let (runtime_tx, runtime_control_tx, runtime_terminate_handle) = spawn_runtime( - HashMap::new(), - execute_request("await new Promise(() => {});"), - runtime_event_tx, - PendingRuntimeMode::Continue, - ) - .unwrap(); - let inner = Arc::new(Inner { - stored_values: Mutex::new(HashMap::new()), - cells: Mutex::new(HashMap::new()), - delegate, - shutting_down: AtomicBool::new(false), - next_cell_id: AtomicU64::new(1), - }); - let task = tokio::spawn(run_cell_control( - inner, - CellControlContext { - cell_id: cell_id("1"), - runtime_tx, - runtime_control_tx, - pending_mode: PendingRuntimeMode::Continue, - runtime_terminate_handle, - cancellation_token: CancellationToken::new(), - }, - event_rx, - control_rx, - CellResponseSender::Runtime(initial_response_tx), - initial_yield_time_ms, - )); - - CellControlHarness { - event_tx, - control_tx, - initial_response_rx, - task, - _runtime_event_rx: runtime_event_rx, - } -} - impl BlockingDelegate { fn new() -> (Arc, mpsc::UnboundedReceiver) { let (events_tx, events_rx) = mpsc::unbounded_channel(); @@ -151,10 +94,15 @@ impl BlockingDelegate { events_tx, notification_finished: AtomicBool::new(false), tool_finished: AtomicBool::new(false), + tool_release: Notify::new(), }), events_rx, ) } + + fn release_tool(&self) { + self.tool_release.notify_one(); + } } impl CodeModeSessionDelegate for BlockingDelegate { @@ -165,10 +113,17 @@ impl CodeModeSessionDelegate for BlockingDelegate 
{ ) -> ToolInvocationFuture<'a> { Box::pin(async move { let _ = self.events_tx.send(DelegateEvent::ToolStarted); - cancellation_token.cancelled().await; - self.tool_finished.store(true, Ordering::Release); - let _ = self.events_tx.send(DelegateEvent::ToolCancelled); - Err("cancelled".to_string()) + tokio::select! { + _ = self.tool_release.notified() => { + self.tool_finished.store(true, Ordering::Release); + Ok(serde_json::Value::Null) + } + _ = cancellation_token.cancelled() => { + self.tool_finished.store(true, Ordering::Release); + let _ = self.events_tx.send(DelegateEvent::ToolCancelled); + Err("cancelled".to_string()) + } + } }) } @@ -227,87 +182,15 @@ async fn next_event(events_rx: &mut mpsc::UnboundedReceiver) -> D .expect("delegate event channel closed") } -#[tokio::test] -async fn yield_timer_preempts_buffered_runtime_output() { - let harness = spawn_cell_control_harness( - Some(/*initial_yield_time_ms*/ 0), - Arc::new(NoopCodeModeSessionDelegate), - ); - harness.event_tx.send(RuntimeEvent::Started).unwrap(); - harness - .event_tx - .send(RuntimeEvent::ContentItem( - FunctionCallOutputContentItem::InputText { - text: "queued output".to_string(), - }, - )) - .unwrap(); - - assert_eq!( - harness.initial_response_rx.await.unwrap(), - Ok(RuntimeResponse::Yielded { - cell_id: cell_id("1"), - content_items: Vec::new(), - }) - ); - - let (termination_tx, termination_rx) = oneshot::channel(); - harness - .control_tx - .send(CellControlCommand::Terminate { - response_tx: termination_tx, - }) - .unwrap(); - drop(harness.event_tx); - assert_eq!( - termination_rx.await.unwrap(), - Ok(RuntimeResponse::Terminated { - cell_id: cell_id("1"), - content_items: vec![FunctionCallOutputContentItem::InputText { - text: "queued output".to_string(), - }], - }) - ); - harness.task.await.unwrap(); -} - -#[tokio::test] -async fn queued_termination_preempts_unobserved_runtime_completion() { - let harness = spawn_cell_control_harness( - Some(/*initial_yield_time_ms*/ 60_000), - 
Arc::new(NoopCodeModeSessionDelegate), - ); - harness - .event_tx - .send(RuntimeEvent::Result { - stored_value_writes: HashMap::new(), - error_text: None, - }) - .unwrap(); - let (termination_tx, termination_rx) = oneshot::channel(); - harness - .control_tx - .send(CellControlCommand::Terminate { - response_tx: termination_tx, - }) - .unwrap(); - - let terminated = Ok(RuntimeResponse::Terminated { - cell_id: cell_id("1"), - content_items: Vec::new(), - }); - assert_eq!(termination_rx.await.unwrap(), terminated.clone()); - assert_eq!(harness.initial_response_rx.await.unwrap(), terminated); - harness.task.await.unwrap(); -} - #[tokio::test] async fn yields_and_resumes() { let service = CodeModeService::new(); let cell = service - .execute(execute_request( - r#"text("before"); yield_control(); text("after");"#, - )) + .execute(ExecuteRequest { + source: r#"text("before"); yield_control(); text("after");"#.to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) .await .unwrap(); @@ -324,7 +207,7 @@ async fn yields_and_resumes() { service .wait(WaitRequest { cell_id: cell_id("1"), - yield_time_ms: 1, + yield_time_ms: 60_000, }) .await .unwrap(), @@ -340,35 +223,32 @@ async fn yields_and_resumes() { #[tokio::test] async fn returns_and_resumes_from_the_pending_frontier() { - let service = CodeModeService::new(); + let (delegate, mut events_rx) = BlockingDelegate::new(); + let service = CodeModeService::with_delegate(delegate.clone()); assert_eq!( service - .execute_to_pending(execute_request( - r#" -await new Promise((resolve) => setTimeout(resolve, 60_000)); + .execute_to_pending(ExecuteRequest { + enabled_tools: vec![blocking_tool()], + source: r#" +await tools.block({}); text("after"); -"#, - )) +"# + .to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) .await .unwrap(), ExecuteToPendingOutcome::Pending { cell_id: cell_id("1"), content_items: Vec::new(), - pending_tool_call_ids: Vec::new(), + pending_tool_call_ids: 
vec!["tool-1".to_string()], } ); - service - .inner - .cells - .lock() - .await - .get(&cell_id("1")) - .unwrap() - .runtime_tx - .send(RuntimeCommand::TimeoutFired { id: 1 }) - .unwrap(); + assert_eq!(next_event(&mut events_rx).await, DelegateEvent::ToolStarted); + delegate.release_tool(); assert_eq!( service @@ -391,55 +271,36 @@ text("after"); #[tokio::test] async fn observed_natural_completion_wins_over_termination() { - let (delegate, mut events_rx) = BlockingDelegate::new(); - let harness = - spawn_cell_control_harness(Some(/*initial_yield_time_ms*/ 60_000), delegate.clone()); - harness.event_tx.send(RuntimeEvent::YieldRequested).unwrap(); + let service = CodeModeService::new(); + let cell = service + .execute(execute_request( + r#"yield_control(); store("finished", true); text("done");"#, + )) + .await + .unwrap(); assert_eq!( - harness.initial_response_rx.await.unwrap(), - Ok(RuntimeResponse::Yielded { + cell.initial_response().await.unwrap(), + RuntimeResponse::Yielded { cell_id: cell_id("1"), content_items: Vec::new(), - }) + } ); - harness - .event_tx - .send(RuntimeEvent::ContentItem( - FunctionCallOutputContentItem::InputText { - text: "done".to_string(), - }, - )) - .unwrap(); - harness - .event_tx - .send(RuntimeEvent::Result { - stored_value_writes: HashMap::new(), - error_text: None, - }) - .unwrap(); - harness - .event_tx - .send(RuntimeEvent::Notify { - call_id: "notify-1".to_string(), - text: "completion observed".to_string(), - }) - .unwrap(); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if service.inner.stored_values.lock().await.get("finished") + == Some(&serde_json::json!(true)) + { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .unwrap(); assert_eq!( - next_event(&mut events_rx).await, - DelegateEvent::NotificationStarted - ); - - let (termination_tx, termination_rx) = oneshot::channel(); - harness - .control_tx - .send(CellControlCommand::Terminate { - response_tx: termination_tx, - }) - .unwrap(); 
- assert_eq!( - termination_rx.await.unwrap(), - Ok(RuntimeResponse::Result { + service.terminate(cell_id("1")).await.unwrap(), + WaitOutcome::LiveCell(RuntimeResponse::Result { cell_id: cell_id("1"), content_items: vec![FunctionCallOutputContentItem::InputText { text: "done".to_string(), @@ -447,16 +308,6 @@ async fn observed_natural_completion_wins_over_termination() { error_text: None, }) ); - harness.task.await.unwrap(); - assert!(delegate.notification_finished.load(Ordering::Acquire)); - assert_eq!( - next_event(&mut events_rx).await, - DelegateEvent::NotificationCancelled - ); - assert_eq!( - next_event(&mut events_rx).await, - DelegateEvent::CellClosed(cell_id("1")) - ); } #[tokio::test] @@ -499,6 +350,36 @@ async fn termination_cancels_pending_callbacks_before_responding() { ); } +#[tokio::test] +async fn shutdown_cancels_notifications_while_natural_completion_is_draining() { + let (delegate, mut events_rx) = HeldNotificationDelegate::new(); + let service = Arc::new(CodeModeService::with_delegate(delegate.clone())); + service + .execute(execute_request(r#"notify("pending");"#)) + .await + .unwrap(); + + assert_eq!( + next_event(&mut events_rx).await, + DelegateEvent::NotificationStarted + ); + + let shutdown_service = Arc::clone(&service); + let shutdown = tokio::spawn(async move { shutdown_service.shutdown().await }); + + assert_eq!( + next_event(&mut events_rx).await, + DelegateEvent::NotificationCancelled + ); + delegate.release_notification(); + + assert_eq!(shutdown.await.unwrap(), Ok(())); + assert_eq!( + next_event(&mut events_rx).await, + DelegateEvent::CellClosed(cell_id("1")) + ); +} + #[tokio::test] async fn repeated_termination_is_rejected_while_callback_cleanup_is_pending() { let (delegate, mut events_rx) = HeldNotificationDelegate::new(); diff --git a/codex-rs/code-mode/src/service_tests.rs b/codex-rs/code-mode/src/service_tests.rs new file mode 100644 index 000000000..688bd89cd --- /dev/null +++ b/codex-rs/code-mode/src/service_tests.rs @@ 
-0,0 +1,972 @@ +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use super::CellId; +use super::CodeModeNestedToolCall; +use super::CodeModeService; +use super::CodeModeSessionDelegate; +use super::NotificationFuture; +use super::ObserveMode; +use super::RuntimeResponse; +use super::ToolInvocationFuture; +use super::WaitOutcome; +use super::WaitRequest; +use super::WaitToPendingOutcome; +use super::WaitToPendingRequest; +use crate::CodeModeToolKind; +use crate::ExecuteRequest; +use crate::ExecuteToPendingOutcome; +use crate::FunctionCallOutputContentItem; +use crate::ToolDefinition; +use codex_protocol::ToolName; +use pretty_assertions::assert_eq; +use serde_json::Value as JsonValue; +use tokio::sync::Notify; +use tokio_util::sync::CancellationToken; + +#[derive(Default)] +struct ReleasableToolDelegate { + tool_release: Notify, +} + +impl ReleasableToolDelegate { + fn release_tool(&self) { + self.tool_release.notify_one(); + } +} + +impl CodeModeSessionDelegate for ReleasableToolDelegate { + fn invoke_tool<'a>( + &'a self, + _invocation: CodeModeNestedToolCall, + cancellation_token: CancellationToken, + ) -> ToolInvocationFuture<'a> { + Box::pin(async move { + tokio::select! 
{ + _ = self.tool_release.notified() => Ok(JsonValue::Null), + _ = cancellation_token.cancelled() => Err("cancelled".to_string()), + } + }) + } + + fn notify<'a>( + &'a self, + _call_id: String, + _cell_id: CellId, + _text: String, + _cancellation_token: CancellationToken, + ) -> NotificationFuture<'a> { + Box::pin(async { Ok(()) }) + } + + fn cell_closed(&self, _cell_id: &CellId) {} +} + +fn execute_request(source: &str) -> ExecuteRequest { + ExecuteRequest { + tool_call_id: "call_1".to_string(), + enabled_tools: Vec::new(), + source: source.to_string(), + yield_time_ms: Some(1), + max_output_tokens: None, + } +} + +fn cell_id(value: &str) -> CellId { + CellId::new(value.to_string()) +} + +fn echo_tool() -> ToolDefinition { + ToolDefinition { + name: "echo".to_string(), + tool_name: ToolName::plain("echo"), + description: String::new(), + kind: CodeModeToolKind::Function, + input_schema: None, + output_schema: None, + } +} + +async fn execute(service: &CodeModeService, request: ExecuteRequest) -> RuntimeResponse { + service + .execute(request) + .await + .unwrap() + .initial_response() + .await + .unwrap() +} + +#[tokio::test] +async fn synchronous_exit_returns_successfully() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#"text("before"); exit(); text("after");"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "before".to_string(), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn stored_values_are_shared_between_cells_but_not_sessions() { + let first_session = CodeModeService::new(); + let second_session = CodeModeService::new(); + + let write_response = execute( + &first_session, + ExecuteRequest { + source: r#"store("key", "visible");"#.to_string(), + yield_time_ms: None, + ..execute_request("") + 
}, + ) + .await; + + let same_session = execute( + &first_session, + ExecuteRequest { + source: r#"text(String(load("key")));"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + let other_session = execute( + &second_session, + ExecuteRequest { + source: r#"text(String(load("key")));"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + write_response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: Vec::new(), + error_text: None, + } + ); + assert_eq!( + same_session, + RuntimeResponse::Result { + cell_id: cell_id("2"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "visible".to_string(), + }], + error_text: None, + } + ); + assert_eq!( + other_session, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "undefined".to_string(), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn shutdown_interrupts_cpu_bound_cells() { + let service = CodeModeService::new(); + + let cell = service + .execute(ExecuteRequest { + source: "while (true) {}".to_string(), + ..execute_request("") + }) + .await + .unwrap(); + assert_eq!( + cell.initial_response().await.unwrap(), + RuntimeResponse::Yielded { + cell_id: cell_id("1"), + content_items: Vec::new(), + } + ); + + tokio::time::timeout(Duration::from_secs(1), service.shutdown()) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn start_cell_rejects_new_cell_after_shutdown_begins() { + let service = CodeModeService::new(); + service.inner.shutting_down.store(true, Ordering::Release); + + let error = service + .start_cell( + cell_id("late-cell"), + execute_request(""), + ObserveMode::YieldAfter(Duration::from_millis(1)), + ) + .await + .err() + .unwrap(); + + assert_eq!(error, "code mode session is shutting down".to_string()); + assert!(service.inner.cells.lock().await.is_empty()); +} + +#[tokio::test] +async fn 
execute_to_pending_returns_completed_for_synchronous_results() { + let service = CodeModeService::new(); + + let response = service + .execute_to_pending(ExecuteRequest { + source: r#"text("done");"#.to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + response, + ExecuteToPendingOutcome::Completed(RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "done".to_string(), + }], + error_text: None, + }) + ); +} + +#[tokio::test] +async fn execute_to_pending_returns_once_the_runtime_is_quiescent() { + let service = CodeModeService::new(); + + let response = tokio::time::timeout( + Duration::from_secs(1), + service.execute_to_pending(ExecuteRequest { + source: r#"text("before"); await new Promise(() => {});"#.to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + response, + ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "before".to_string(), + }], + pending_tool_call_ids: Vec::new(), + } + ); + + let termination = service.terminate(cell_id("1")).await.unwrap(); + + assert_eq!( + termination, + WaitOutcome::LiveCell(RuntimeResponse::Terminated { + cell_id: cell_id("1"), + content_items: Vec::new(), + }) + ); +} + +#[tokio::test] +async fn execute_to_pending_identifies_tool_calls_in_paused_frontier() { + let service = CodeModeService::new(); + + let response = service + .execute_to_pending(ExecuteRequest { + enabled_tools: vec![echo_tool()], + source: r#" +await Promise.all([ + tools.echo({ value: "first" }), + tools.echo({ value: "second" }), +]); +"# + .to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + response, + ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: Vec::new(), + 
pending_tool_call_ids: vec!["tool-1".to_string(), "tool-2".to_string()], + } + ); + + let termination = service.terminate(cell_id("1")).await.unwrap(); + + assert_eq!( + termination, + WaitOutcome::LiveCell(RuntimeResponse::Terminated { + cell_id: cell_id("1"), + content_items: Vec::new(), + }) + ); +} + +#[tokio::test] +async fn execute_to_pending_excludes_delayed_timeout_tool_calls_until_wait() { + let service = CodeModeService::new(); + + let initial_response = service + .execute_to_pending(ExecuteRequest { + enabled_tools: vec![echo_tool()], + source: r#" +setTimeout(() => { + tools.echo({ value: "delayed" }); +}, 1000); +await Promise.all([ + tools.echo({ value: "second" }), + tools.echo({ value: "third" }), +]); +"# + .to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + initial_response, + ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: Vec::new(), + pending_tool_call_ids: vec!["tool-1".to_string(), "tool-2".to_string()], + } + ); + + tokio::time::sleep(Duration::from_millis(1100)).await; + + let resumed_response = tokio::time::timeout( + Duration::from_secs(1), + service.wait_to_pending(WaitToPendingRequest { + cell_id: cell_id("1"), + }), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + resumed_response, + WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: Vec::new(), + pending_tool_call_ids: vec!["tool-3".to_string()], + }) + ); + + let termination = service.terminate(cell_id("1")).await.unwrap(); + + assert_eq!( + termination, + WaitOutcome::LiveCell(RuntimeResponse::Terminated { + cell_id: cell_id("1"), + content_items: Vec::new(), + }) + ); +} + +#[tokio::test] +async fn wait_to_pending_returns_after_resumed_runtime_becomes_quiescent_again() { + let delegate = Arc::new(ReleasableToolDelegate::default()); + let service = CodeModeService::with_delegate(delegate.clone()); + + let initial_response = 
service + .execute_to_pending(ExecuteRequest { + enabled_tools: vec![echo_tool()], + source: r#" +await tools.echo({}); +text("after"); +await new Promise(() => {}); +"# + .to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + initial_response, + ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: Vec::new(), + pending_tool_call_ids: vec!["tool-1".to_string()], + } + ); + + delegate.release_tool(); + + let resumed_response = tokio::time::timeout( + Duration::from_secs(1), + service.wait_to_pending(WaitToPendingRequest { + cell_id: cell_id("1"), + }), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + resumed_response, + WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "after".to_string(), + }], + pending_tool_call_ids: Vec::new(), + }) + ); + + let termination = service.terminate(cell_id("1")).await.unwrap(); + + assert_eq!( + termination, + WaitOutcome::LiveCell(RuntimeResponse::Terminated { + cell_id: cell_id("1"), + content_items: Vec::new(), + }) + ); +} + +#[tokio::test] +async fn wait_to_pending_returns_completed_after_resumed_runtime_finishes() { + let delegate = Arc::new(ReleasableToolDelegate::default()); + let service = CodeModeService::with_delegate(delegate.clone()); + + let initial_response = service + .execute_to_pending(ExecuteRequest { + enabled_tools: vec![echo_tool()], + source: r#" +await tools.echo({}); +text("done"); +"# + .to_string(), + yield_time_ms: Some(60_000), + ..execute_request("") + }) + .await + .unwrap(); + + assert_eq!( + initial_response, + ExecuteToPendingOutcome::Pending { + cell_id: cell_id("1"), + content_items: Vec::new(), + pending_tool_call_ids: vec!["tool-1".to_string()], + } + ); + + delegate.release_tool(); + + let resumed_response = tokio::time::timeout( + Duration::from_secs(1), + 
service.wait_to_pending(WaitToPendingRequest { + cell_id: cell_id("1"), + }), + ) + .await + .unwrap() + .unwrap(); + + assert_eq!( + resumed_response, + WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Completed( + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "done".to_string(), + }], + error_text: None, + } + )) + ); +} + +#[tokio::test] +async fn v8_console_is_not_exposed_on_global_this() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#"text(String(Object.hasOwn(globalThis, "console")));"#.to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "false".to_string(), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn date_locale_string_formats_with_icu_data() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +const value = new Date("2025-01-02T03:04:05Z") + .toLocaleString("fr-FR", { + weekday: "long", + month: "long", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + second: "2-digit", + hour12: false, + timeZone: "UTC", + }); +text(value); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "jeudi 2 janvier \u{e0} 03:04:05".to_string(), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn intl_date_time_format_formats_with_icu_data() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +const formatter = new Intl.DateTimeFormat("fr-FR", { + weekday: "long", + month: "long", + day: "numeric", + 
hour: "2-digit", + minute: "2-digit", + second: "2-digit", + hour12: false, + timeZone: "UTC", +}); +text(formatter.format(new Date("2025-01-02T03:04:05Z"))); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputText { + text: "jeudi 2 janvier \u{e0} 03:04:05".to_string(), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn output_helpers_return_undefined() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +const returnsUndefined = [ + text("first"), + image("data:image/png;base64,AAA"), + notify("ping"), +].map((value) => value === undefined); +text(JSON.stringify(returnsUndefined)); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![ + FunctionCallOutputContentItem::InputText { + text: "first".to_string(), + }, + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: Some(crate::DEFAULT_IMAGE_DETAIL), + }, + FunctionCallOutputContentItem::InputText { + text: "[true,true,true]".to_string(), + }, + ], + error_text: None, + } + ); +} + +#[tokio::test] +async fn image_helper_accepts_raw_mcp_image_block_with_original_detail() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +image({ + type: "image", + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + mimeType: "image/png", + _meta: { "codex/imageDetail": "original" }, +}); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: 
vec![FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==".to_string(), + detail: Some(crate::ImageDetail::Original), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn generated_image_helper_appends_image_and_output_hint() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +generatedImage({ + image_url: "data:image/png;base64,AAA", + output_hint: "generated image save hint", +}); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![ + FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: Some(crate::DEFAULT_IMAGE_DETAIL), + }, + FunctionCallOutputContentItem::InputText { + text: "generated image save hint".to_string(), + }, + ], + error_text: None, + } + ); +} + +#[tokio::test] +async fn image_helper_second_arg_overrides_explicit_object_detail() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +image( + { + image_url: "data:image/png;base64,AAA", + detail: "high", + }, + "original", +); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: Some(crate::ImageDetail::Original), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn image_helper_second_arg_overrides_raw_mcp_image_detail() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +image( + { + type: "image", + data: 
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + mimeType: "image/png", + _meta: { "codex/imageDetail": "original" }, + }, + "high", +); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==".to_string(), + detail: Some(crate::ImageDetail::High), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn image_helper_accepts_low_detail() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +image({ + image_url: "data:image/png;base64,AAA", + detail: "low", +}); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: vec![FunctionCallOutputContentItem::InputImage { + image_url: "data:image/png;base64,AAA".to_string(), + detail: Some(crate::ImageDetail::Low), + }], + error_text: None, + } + ); +} + +#[tokio::test] +async fn image_helpers_reject_remote_urls() { + for image_url in [ + "http://example.com/image.jpg", + "https://example.com/image.jpg", + ] { + for source in [ + format!("image({image_url:?});"), + format!("generatedImage({{ image_url: {image_url:?} }});"), + ] { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source, + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: Vec::new(), + error_text: Some( + "Tool call failed: remote image URLs are not supported in tool outputs. 
Pass a base64 data URI instead".to_string(), + ), + } + ); + } + } +} + +#[tokio::test] +async fn image_helper_rejects_unsupported_detail() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +image({ + image_url: "data:image/png;base64,AAA", + detail: "medium", +}); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: Vec::new(), + error_text: Some("image detail must be one of: auto, low, high, original".to_string()), + } + ); +} + +#[tokio::test] +async fn image_helper_rejects_raw_mcp_result_container() { + let service = CodeModeService::new(); + + let response = execute( + &service, + ExecuteRequest { + source: r#" +image({ + content: [ + { + type: "image", + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4z8DwHwAFAAH/iZk9HQAAAABJRU5ErkJggg==", + mimeType: "image/png", + _meta: { "codex/imageDetail": "original" }, + }, + ], + isError: false, +}); +"# + .to_string(), + yield_time_ms: None, + ..execute_request("") + }, + ) + .await; + + assert_eq!( + response, + RuntimeResponse::Result { + cell_id: cell_id("1"), + content_items: Vec::new(), + error_text: Some( + "image expects a non-empty image URL string, an object with image_url and optional detail, or a raw MCP image block".to_string(), + ), + } + ); +} + +#[tokio::test] +async fn wait_reports_missing_cell_separately_from_runtime_results() { + let service = CodeModeService::new(); + + let response = service + .wait(WaitRequest { + cell_id: cell_id("missing"), + yield_time_ms: 1, + }) + .await + .unwrap(); + + assert_eq!( + response, + WaitOutcome::MissingCell(RuntimeResponse::Result { + cell_id: cell_id("missing"), + content_items: Vec::new(), + error_text: Some("exec cell missing not found".to_string()), + }) + ); +}