diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 75c02adfb..17c51f58a 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2515,6 +2515,14 @@ dependencies = [ [[package]] name = "codex-code-mode-host" version = "0.0.0" +dependencies = [ + "anyhow", + "codex-code-mode", + "codex-code-mode-protocol", + "pretty_assertions", + "tokio", + "tokio-util", +] [[package]] name = "codex-code-mode-protocol" diff --git a/codex-rs/code-mode-host/Cargo.toml b/codex-rs/code-mode-host/Cargo.toml index a2c384d01..d37bd4101 100644 --- a/codex-rs/code-mode-host/Cargo.toml +++ b/codex-rs/code-mode-host/Cargo.toml @@ -8,5 +8,20 @@ license.workspace = true name = "codex-code-mode-host" path = "src/main.rs" +[lib] +doctest = false +name = "codex_code_mode_host" +path = "src/lib.rs" + [lints] workspace = true + +[dependencies] +anyhow = { workspace = true } +codex-code-mode = { workspace = true } +codex-code-mode-protocol = { workspace = true } +tokio = { workspace = true, features = ["io-std", "io-util", "macros", "rt", "sync", "time"] } +tokio-util = { workspace = true, features = ["rt"] } + +[dev-dependencies] +pretty_assertions = { workspace = true } diff --git a/codex-rs/code-mode-host/src/delegate.rs b/codex-rs/code-mode-host/src/delegate.rs new file mode 100644 index 000000000..50d45e7de --- /dev/null +++ b/codex-rs/code-mode-host/src/delegate.rs @@ -0,0 +1,85 @@ +use std::sync::Arc; + +use codex_code_mode_protocol::CellId; +use codex_code_mode_protocol::CodeModeNestedToolCall; +use codex_code_mode_protocol::CodeModeSessionDelegate; +use codex_code_mode_protocol::NotificationFuture; +use codex_code_mode_protocol::ToolInvocationFuture; +use codex_code_mode_protocol::host::DelegateRequest; +use codex_code_mode_protocol::host::DelegateResponse; +use codex_code_mode_protocol::host::SessionId; +use tokio_util::sync::CancellationToken; + +use crate::peer::HostPeer; + +pub(super) struct RemoteDelegate { + session_id: SessionId, + peer: Arc, +} + +impl RemoteDelegate { + pub(super) fn new(session_id: SessionId, peer: Arc) -> Self { + Self { session_id, peer } + } +} + +impl CodeModeSessionDelegate for RemoteDelegate { + fn invoke_tool<'a>( + &'a self, + invocation: CodeModeNestedToolCall, + cancellation_token: CancellationToken, + ) -> ToolInvocationFuture<'a> { + Box::pin(async move { + match self + .peer + .call( + self.session_id.clone(), + DelegateRequest::InvokeTool { + invocation: invocation.into(), + }, + cancellation_token, + ) + .await? + { + DelegateResponse::ToolResult { result } => Ok(result), + DelegateResponse::NotificationDelivered => { + Err("code-mode client returned an invalid tool result".to_string()) + } + } + }) + } + + fn notify<'a>( + &'a self, + call_id: String, + cell_id: CellId, + text: String, + cancellation_token: CancellationToken, + ) -> NotificationFuture<'a> { + Box::pin(async move { + match self + .peer + .call( + self.session_id.clone(), + DelegateRequest::Notify { + call_id, + cell_id: cell_id.into(), + text, + }, + cancellation_token, + ) + .await? + { + DelegateResponse::NotificationDelivered => Ok(()), + DelegateResponse::ToolResult { .. } => { + Err("code-mode client returned an invalid notification result".to_string()) + } + } + }) + } + + fn cell_closed(&self, cell_id: &CellId) { + self.peer + .close_cell(self.session_id.clone(), cell_id.clone()); + } +} diff --git a/codex-rs/code-mode-host/src/host_tests.rs b/codex-rs/code-mode-host/src/host_tests.rs new file mode 100644 index 000000000..b189b89e0 --- /dev/null +++ b/codex-rs/code-mode-host/src/host_tests.rs @@ -0,0 +1,486 @@ +use codex_code_mode_protocol::host::Capability; +use codex_code_mode_protocol::host::CapabilitySet; +use codex_code_mode_protocol::host::ClientHello; +use codex_code_mode_protocol::host::ClientToHost; +use codex_code_mode_protocol::host::EncodedFrame; +use codex_code_mode_protocol::host::FramedReader; +use codex_code_mode_protocol::host::FramedWriter; +use codex_code_mode_protocol::host::HandshakeRejectReason; +use codex_code_mode_protocol::host::HostHello; +use codex_code_mode_protocol::host::HostRequest; +use codex_code_mode_protocol::host::HostResponse; +use codex_code_mode_protocol::host::HostToClient; +use codex_code_mode_protocol::host::ProtocolVersion; +use codex_code_mode_protocol::host::RequestId; +use codex_code_mode_protocol::host::SessionId; +use codex_code_mode_protocol::host::SupportedProtocolVersions; +use codex_code_mode_protocol::host::WireExecuteRequest; +use codex_code_mode_protocol::host::WireResult; +use pretty_assertions::assert_eq; +use tokio::sync::Semaphore; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tokio_util::task::TaskTracker; + +use super::HostState; +use super::MAX_ACTIVE_CELLS; +use super::MAX_IN_FLIGHT_REQUESTS; +use super::MAX_RECENT_REQUEST_IDS; +use super::RequestKind; +use super::RequestRegistry; +use super::SeenSessionIds; +use super::peer::HostPeer; +use super::run; + +fn client_hello( + versions: impl IntoIterator, + required_capabilities: CapabilitySet, +) -> ClientToHost { + ClientToHost::ClientHello( + ClientHello::new( + SupportedProtocolVersions::try_new(versions).expect("supported versions"), + required_capabilities, + CapabilitySet::empty(), + ) + .expect("client hello"), + ) +} + +fn session_id(value: &str) -> SessionId { + SessionId::new(value).expect("session ID") +} + +fn request_id(value: i64) -> RequestId { + RequestId::new(value) +} + +async fn decode_frame(frame: EncodedFrame) -> HostToClient { + let (reader, writer) = tokio::io::duplex(/*max_buf_size*/ 4096); + let writer = tokio::spawn(async move { + FramedWriter::new(writer) + .write_frame(&frame) + .await + .expect("write encoded frame"); + }); + let message = FramedReader::new(reader) + .read() + .await + .expect("read encoded frame") + .expect("encoded frame message"); + writer.await.expect("frame writer task"); + message +} + +fn execute_request(source: &str) -> WireExecuteRequest { + WireExecuteRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: Vec::new(), + source: source.to_string(), + yield_time_ms: Some(60_000), + max_output_tokens: Some(1_000), + } +} + +#[tokio::test] +async fn handshake_and_multiple_session_lifecycles_are_ordered() { + let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 4096); + let (host_reader, host_writer) = tokio::io::split(host_stream); + let (client_reader, client_writer) = tokio::io::split(client_stream); + let host = tokio::spawn(run(host_reader, host_writer)); + let mut reader = FramedReader::new(client_reader); + let mut writer = FramedWriter::new(client_writer); + + writer + .write(&client_hello([ProtocolVersion::V1], CapabilitySet::empty())) + .await + .expect("write hello"); + assert_eq!( + reader.read::().await.expect("read hello"), + Some(HostToClient::HostHello(HostHello::new( + ProtocolVersion::V1, + CapabilitySet::empty(), + ))) + ); + + for (request_id, id) in [ + (request_id(/*value*/ 1), "session-1"), + (request_id(/*value*/ 2), "session-2"), + ] { + writer + .write(&ClientToHost::Request { + id: request_id, + request: HostRequest::OpenSession { + session_id: session_id(id), + }, + }) + .await + .expect("open session"); + assert_eq!( + reader.read::().await.expect("session ready"), + Some(HostToClient::Response { + id: request_id, + result: WireResult::Ok { + value: HostResponse::SessionReady { + session_id: session_id(id), + }, + }, + }) + ); + } + + for (request_id, id) in [ + (request_id(/*value*/ 3), "session-1"), + (request_id(/*value*/ 4), "session-2"), + ] { + writer + .write(&ClientToHost::Request { + id: request_id, + request: HostRequest::ShutdownSession { + session_id: session_id(id), + }, + }) + .await + .expect("shutdown session"); + assert_eq!( + reader.read::().await.expect("session closed"), + Some(HostToClient::Response { + id: request_id, + result: WireResult::Ok { + value: HostResponse::SessionClosed { + session_id: session_id(id), + }, + }, + }) + ); + } + + drop(writer); + drop(reader); + host.await.expect("host task").expect("host connection"); +} + +#[tokio::test] +async fn incompatible_or_invalid_handshake_is_rejected() { + let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 1024); + let (host_reader, host_writer) = tokio::io::split(host_stream); + let (client_reader, client_writer) = tokio::io::split(client_stream); + let host = tokio::spawn(run(host_reader, host_writer)); + let mut reader = FramedReader::new(client_reader); + let mut writer = FramedWriter::new(client_writer); + let version_two = ProtocolVersion::new(/*value*/ 2).expect("protocol version"); + + writer + .write(&client_hello([version_two], CapabilitySet::empty())) + .await + .expect("write hello"); + assert_eq!( + reader.read::().await.expect("rejection"), + Some(HostToClient::HandshakeRejected { + reason: HandshakeRejectReason::NoCompatibleVersion { + supported_versions: SupportedProtocolVersions::try_new([ProtocolVersion::V1]) + .expect("host versions"), + }, + }) + ); + host.await.expect("host task").expect("host connection"); + + let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 1024); + let (host_reader, host_writer) = tokio::io::split(host_stream); + let (client_reader, client_writer) = tokio::io::split(client_stream); + let host = tokio::spawn(run(host_reader, host_writer)); + let mut reader = FramedReader::new(client_reader); + let mut writer = FramedWriter::new(client_writer); + writer + .write(&ClientToHost::Request { + id: request_id(/*value*/ 1), + request: HostRequest::OpenSession { + session_id: session_id("session-1"), + }, + }) + .await + .expect("write invalid first message"); + assert_eq!( + reader.read::().await.expect("rejection"), + Some(HostToClient::HandshakeRejected { + reason: HandshakeRejectReason::InvalidHello { + message: "first message must be connection/hello".to_string(), + }, + }) + ); + host.await.expect("host task").expect("host connection"); +} + +#[tokio::test] +async fn unsupported_required_capability_is_rejected() { + let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 1024); + let (host_reader, host_writer) = tokio::io::split(host_stream); + let (client_reader, client_writer) = tokio::io::split(client_stream); + let host = tokio::spawn(run(host_reader, host_writer)); + let mut reader = FramedReader::new(client_reader); + let mut writer = FramedWriter::new(client_writer); + let capability = Capability::new("required").expect("capability"); + + writer + .write(&client_hello( + [ProtocolVersion::V1], + CapabilitySet::try_new([capability.clone()]).expect("capabilities"), + )) + .await + .expect("write hello"); + assert_eq!( + reader.read::().await.expect("rejection"), + Some(HostToClient::HandshakeRejected { + reason: HandshakeRejectReason::MissingRequiredCapability { capability }, + }) + ); + host.await.expect("host task").expect("host connection"); +} + +#[tokio::test] +async fn session_id_cannot_be_reused_after_shutdown() { + let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 2048); + let (host_reader, host_writer) = tokio::io::split(host_stream); + let (client_reader, client_writer) = tokio::io::split(client_stream); + let host = tokio::spawn(run(host_reader, host_writer)); + let mut reader = FramedReader::new(client_reader); + let mut writer = FramedWriter::new(client_writer); + writer + .write(&client_hello([ProtocolVersion::V1], CapabilitySet::empty())) + .await + .expect("write hello"); + reader + .read::() + .await + .expect("read hello") + .expect("host hello"); + + let id = session_id("session-1"); + for (request_id, request) in [ + ( + request_id(/*value*/ 1), + HostRequest::OpenSession { + session_id: id.clone(), + }, + ), + ( + request_id(/*value*/ 2), + HostRequest::ShutdownSession { + session_id: id.clone(), + }, + ), + ] { + writer + .write(&ClientToHost::Request { + id: request_id, + request, + }) + .await + .expect("session request"); + reader + .read::() + .await + .expect("session response") + .expect("session response message"); + } + writer + .write(&ClientToHost::Request { + id: request_id(/*value*/ 3), + request: HostRequest::OpenSession { session_id: id }, + }) + .await + .expect("reuse session ID"); + assert_eq!( + reader.read::().await.expect("reuse response"), + Some(HostToClient::Response { + id: request_id(/*value*/ 3), + result: WireResult::Err { + message: "code-mode session ID `session-1` was reused".to_string(), + }, + }) + ); + drop(writer); + drop(reader); + host.await.expect("host task").expect("host connection"); +} + +#[test] +fn request_cancellation_tombstones_are_bounded() { + let mut requests = RequestRegistry::default(); + let duplicate = request_id(/*value*/ -1); + requests + .start(duplicate, RequestKind::OpenSession) + .expect("start duplicate probe"); + assert!(requests.start(duplicate, RequestKind::OpenSession).is_err()); + requests.finish(duplicate); + for value in 1..=MAX_RECENT_REQUEST_IDS as i64 + 100 { + let id = request_id(value); + requests + .start(id, RequestKind::Wait) + .expect("start request"); + requests.cancel(id); + requests.finish(id); + } + for value in 10_000..20_000 { + requests.cancel(request_id(value)); + } + + assert!(requests.active.is_empty()); + assert_eq!(requests.recent.len(), MAX_RECENT_REQUEST_IDS); + assert_eq!(requests.recent_order.len(), MAX_RECENT_REQUEST_IDS); +} + +#[tokio::test] +async fn request_task_panic_disconnects_host() { + let (outgoing_tx, _outgoing_rx) = mpsc::channel(/*max_capacity*/ 1); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + let state = HostState { + sessions: Mutex::new(HashMap::new()), + seen_session_ids: Mutex::new(SeenSessionIds::default()), + requests: Mutex::new(RequestRegistry::default()), + request_tasks: TaskTracker::new(), + request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)), + active_cell_permits: Arc::new(Semaphore::new(MAX_ACTIVE_CELLS)), + closing: AtomicBool::new(false), + peer: Arc::clone(&peer), + }; + let task = state.request_tasks.spawn(async { + panic!("request panic probe"); + }); + state.supervise_request_task(task); + + tokio::time::timeout(Duration::from_secs(1), peer.disconnected()) + .await + .expect("request panic should disconnect host"); + assert!( + peer.failure() + .expect("request failure") + .contains("request task failed") + ); +} + +#[tokio::test] +async fn execute_request_id_remains_active_until_initial_response() { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(/*max_capacity*/ 4); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + let state = Arc::new(HostState { + sessions: Mutex::new(HashMap::new()), + seen_session_ids: Mutex::new(SeenSessionIds::default()), + requests: Mutex::new(RequestRegistry::default()), + request_tasks: TaskTracker::new(), + request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)), + active_cell_permits: Arc::new(Semaphore::new(MAX_ACTIVE_CELLS)), + closing: AtomicBool::new(false), + peer, + }); + let session_id = session_id("session-1"); + state + .open_session(session_id.clone()) + .expect("open session"); + let request_id = request_id(/*value*/ 1); + + state + .spawn_request( + request_id, + HostRequest::Execute { + session_id: session_id.clone(), + request: execute_request("await new Promise(() => {});"), + }, + ) + .expect("spawn execute request"); + let started = decode_frame(outgoing_rx.recv().await.expect("execution started frame")).await; + let HostToClient::Response { + id, + result: + WireResult::Ok { + value: HostResponse::ExecutionStarted { cell_id }, + }, + } = started + else { + panic!("expected execution started response"); + }; + assert_eq!(id, request_id); + assert!( + state + .requests + .lock() + .unwrap_or_else(PoisonError::into_inner) + .active + .contains_key(&request_id) + ); + + state + .session(&session_id) + .expect("session") + .terminate(cell_id.into()) + .await + .expect("terminate cell"); + state.disconnect().await; +} + +#[tokio::test] +async fn active_cell_limit_rejects_execute_without_disconnecting() { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(/*max_capacity*/ 1); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + let state = HostState { + sessions: Mutex::new(HashMap::new()), + seen_session_ids: Mutex::new(SeenSessionIds::default()), + requests: Mutex::new(RequestRegistry::default()), + request_tasks: TaskTracker::new(), + request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)), + active_cell_permits: Arc::new(Semaphore::new(/*permits*/ 0)), + closing: AtomicBool::new(false), + peer: Arc::clone(&peer), + }; + let session_id = session_id("session-1"); + state + .open_session(session_id.clone()) + .expect("open session"); + let request_id = request_id(/*value*/ 1); + + state + .handle_request( + request_id, + HostRequest::Execute { + session_id, + request: execute_request("text(\"hello\");"), + }, + CancellationToken::new(), + ) + .await; + + assert_eq!( + decode_frame(outgoing_rx.recv().await.expect("execute response frame")).await, + HostToClient::Response { + id: request_id, + result: WireResult::Err { + message: "code-mode host has too many active cells".to_string(), + }, + } + ); + assert!(!peer.is_disconnected()); + state.disconnect().await; +} + +#[tokio::test] +async fn cell_forwarding_panic_disconnects_host() { + let (outgoing_tx, _outgoing_rx) = mpsc::channel(/*max_capacity*/ 1); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + peer.spawn_critical("cell forwarding", async { + panic!("cell forwarding panic probe"); + }); + + tokio::time::timeout(Duration::from_secs(1), peer.disconnected()) + .await + .expect("cell panic should disconnect host"); + assert!( + peer.failure() + .expect("cell failure") + .contains("cell forwarding task failed") + ); +} +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::PoisonError; +use std::sync::atomic::AtomicBool; +use std::time::Duration; diff --git a/codex-rs/code-mode-host/src/lib.rs b/codex-rs/code-mode-host/src/lib.rs new file mode 100644 index 000000000..d2c1f507b --- /dev/null +++ b/codex-rs/code-mode-host/src/lib.rs @@ -0,0 +1,605 @@ +use std::collections::HashMap; +use std::collections::HashSet; +use std::collections::VecDeque; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::PoisonError; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use anyhow::Context; +use anyhow::Result; +use codex_code_mode::InProcessCodeModeSession; +use codex_code_mode_protocol::host::CapabilitySet; +use codex_code_mode_protocol::host::ClientToHost; +use codex_code_mode_protocol::host::EncodedFrame; +use codex_code_mode_protocol::host::FramedReader; +use codex_code_mode_protocol::host::FramedWriter; +use codex_code_mode_protocol::host::HandshakeRejectReason; +use codex_code_mode_protocol::host::HostHello; +use codex_code_mode_protocol::host::HostRequest; +use codex_code_mode_protocol::host::HostResponse; +use codex_code_mode_protocol::host::HostToClient; +use codex_code_mode_protocol::host::ProtocolVersion; +use codex_code_mode_protocol::host::RequestId; +use codex_code_mode_protocol::host::SessionId; +use codex_code_mode_protocol::host::SupportedProtocolVersions; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::sync::Semaphore; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tokio_util::task::TaskTracker; + +use self::delegate::RemoteDelegate; +use self::peer::HostPeer; + +mod delegate; +mod peer; + +const MAX_IN_FLIGHT_REQUESTS: usize = 256; +const MAX_ACTIVE_CELLS: usize = 128; +const MAX_RECENT_REQUEST_IDS: usize = 4096; +const MAX_RECENT_SESSION_IDS: usize = 4096; +const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +/// Runs one code-mode host connection over the process standard streams. +pub async fn run_stdio() -> Result<()> { + run(tokio::io::stdin(), tokio::io::stdout()).await +} + +/// Runs one code-mode host connection over an ordered input/output pair. +async fn run(reader: R, writer: W) -> Result<()> +where + R: AsyncRead + Send + Unpin + 'static, + W: AsyncWrite + Send + Unpin + 'static, +{ + let mut reader = FramedReader::new(reader); + let mut writer = FramedWriter::new(writer); + if !negotiate(&mut reader, &mut writer).await? { + return Ok(()); + } + + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(/*max_capacity*/ 128); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + let state = Arc::new(HostState { + sessions: Mutex::new(HashMap::new()), + seen_session_ids: Mutex::new(SeenSessionIds::default()), + requests: Mutex::new(RequestRegistry::default()), + request_tasks: TaskTracker::new(), + request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)), + active_cell_permits: Arc::new(Semaphore::new(MAX_ACTIVE_CELLS)), + closing: AtomicBool::new(false), + peer: Arc::clone(&peer), + }); + let writer_disconnected = peer.disconnection_token(); + let writer_task = tokio::spawn(async move { + loop { + tokio::select! { + _ = writer_disconnected.cancelled() => return Ok::<(), anyhow::Error>(()), + frame = outgoing_rx.recv() => { + let Some(frame) = frame else { + return Ok(()); + }; + if let Err(err) = writer.write_frame(&frame).await { + return Err( + anyhow::Error::new(err) + .context("failed to write code-mode host message") + ); + } + } + } + } + }); + let writer_peer = Arc::clone(&peer); + let writer_supervisor = tokio::spawn(async move { + match writer_task.await { + Ok(Ok(())) if !writer_peer.is_disconnected() => { + writer_peer.fail("code-mode writer task exited unexpectedly".to_string()); + } + Ok(Ok(())) => {} + Ok(Err(err)) => { + writer_peer.fail(format!("code-mode writer task failed: {err:#}")); + } + Err(err) => { + writer_peer.fail(format!("code-mode writer task failed: {err}")); + } + } + }); + + let input_result = async { + loop { + let message = tokio::select! { + _ = peer.disconnected() => break, + message = reader.read::() => message + .context("failed to read code-mode client message")?, + }; + let Some(message) = message else { + break; + }; + match message { + ClientToHost::ClientHello(_) => { + anyhow::bail!("received a second code-mode client hello"); + } + ClientToHost::Request { id, request } => { + state.spawn_request(id, request)?; + } + ClientToHost::CancelRequest { id } => { + state.cancel_request(id); + } + ClientToHost::DelegateResponse { id, result } => { + peer.complete(id, result.into_result()).await; + } + } + } + Ok::<(), anyhow::Error>(()) + } + .await; + + peer.disconnect(); + if tokio::time::timeout(SHUTDOWN_TIMEOUT, state.disconnect()) + .await + .is_err() + { + peer.fail("timed out shutting down code-mode host state".to_string()); + } + drop(state); + tokio::time::timeout(SHUTDOWN_TIMEOUT, writer_supervisor) + .await + .context("timed out supervising code-mode writer task")? + .context("code-mode writer supervisor task failed")?; + let failure = peer.failure(); + drop(peer); + input_result?; + if let Some(failure) = failure { + anyhow::bail!(failure); + } + Ok(()) +} + +async fn negotiate(reader: &mut FramedReader, writer: &mut FramedWriter) -> Result +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let Some(first_message) = reader + .read::() + .await + .context("failed to read code-mode client hello")? + else { + return Ok(false); + }; + let ClientToHost::ClientHello(client_hello) = first_message else { + writer + .write(&HostToClient::HandshakeRejected { + reason: HandshakeRejectReason::InvalidHello { + message: "first message must be connection/hello".to_string(), + }, + }) + .await + .context("failed to reject invalid code-mode client hello")?; + return Ok(false); + }; + + let supported_versions = SupportedProtocolVersions::try_new([ProtocolVersion::V1])?; + if !client_hello + .supported_versions() + .contains(ProtocolVersion::V1) + { + writer + .write(&HostToClient::HandshakeRejected { + reason: HandshakeRejectReason::NoCompatibleVersion { supported_versions }, + }) + .await + .context("failed to reject incompatible code-mode client")?; + return Ok(false); + } + + let host_capabilities = CapabilitySet::empty(); + if let Some(capability) = client_hello + .required_capabilities() + .iter() + .find(|capability| !host_capabilities.contains(capability)) + { + writer + .write(&HostToClient::HandshakeRejected { + reason: HandshakeRejectReason::MissingRequiredCapability { + capability: capability.clone(), + }, + }) + .await + .context("failed to reject unsupported code-mode capability")?; + return Ok(false); + } + + writer + .write(&HostToClient::HostHello(HostHello::new( + ProtocolVersion::V1, + host_capabilities, + ))) + .await + .context("failed to write code-mode host hello")?; + Ok(true) +} + +struct HostState { + sessions: Mutex>>, + seen_session_ids: Mutex, + requests: Mutex, + request_tasks: TaskTracker, + request_permits: Arc, + active_cell_permits: Arc, + closing: AtomicBool, + peer: Arc, +} + +impl HostState { + fn spawn_request( + self: &Arc, + request_id: RequestId, + request: HostRequest, + ) -> Result<(), anyhow::Error> { + let cancellation = self + .requests + .lock() + .unwrap_or_else(PoisonError::into_inner) + .start(request_id, RequestKind::from(&request))?; + let Ok(permit) = Arc::clone(&self.request_permits).try_acquire_owned() else { + self.respond( + request_id, + Err("code-mode host has too many in-flight requests".to_string()), + ); + self.finish_request(request_id); + return Ok(()); + }; + let state = Arc::clone(self); + let request_task = self.request_tasks.spawn(async move { + let _permit = permit; + state + .handle_request(request_id, request, cancellation) + .await; + state.finish_request(request_id); + }); + self.supervise_request_task(request_task); + Ok(()) + } + + fn supervise_request_task(&self, task: tokio::task::JoinHandle<()>) { + let peer = Arc::clone(&self.peer); + tokio::spawn(async move { + if let Err(err) = task.await { + peer.fail(format!("code-mode request task failed: {err}")); + } + }); + } + + async fn handle_request( + &self, + request_id: RequestId, + request: HostRequest, + cancellation: CancellationToken, + ) { + if self.closing.load(Ordering::Acquire) { + self.respond( + request_id, + Err("code-mode host is shutting down".to_string()), + ); + return; + } + match request { + HostRequest::OpenSession { session_id } => { + let result = self + .open_session(session_id.clone()) + .map(|()| HostResponse::SessionReady { session_id }); + self.respond(request_id, result); + } + HostRequest::Execute { + session_id, + request, + } => { + if cancellation.is_cancelled() { + self.respond(request_id, Err("code-mode request cancelled".to_string())); + return; + } + let request = match request.try_into() { + Ok(request) => request, + Err(err) => { + self.respond( + request_id, + Err(format!("invalid code-mode execute request: {err}")), + ); + return; + } + }; + let session = match self.session(&session_id) { + Ok(session) => session, + Err(err) => { + self.respond(request_id, Err(err)); + return; + } + }; + let Ok(active_cell_permit) = + Arc::clone(&self.active_cell_permits).try_acquire_owned() + else { + self.respond( + request_id, + Err("code-mode host has too many active cells".to_string()), + ); + return; + }; + let result = session.execute(request).await; + match result { + Ok(started) => { + let cell_id = started.cell_id.clone(); + self.respond( + request_id, + Ok(HostResponse::ExecutionStarted { + cell_id: cell_id.into(), + }), + ); + let initial_response_sent = self.peer.start_cell( + session_id, + request_id, + started, + active_cell_permit, + ); + let _ = initial_response_sent.await; + } + Err(err) => self.respond(request_id, Err(err)), + } + } + HostRequest::Wait { + session_id, + request, + } => { + let result = match self.session(&session_id) { + Ok(session) => { + tokio::select! { + biased; + _ = cancellation.cancelled() => { + Err("code-mode request cancelled".to_string()) + } + result = session.wait(request.into()) => result.map(|outcome| { + HostResponse::WaitCompleted { + outcome: outcome.into(), + } + }), + } + } + Err(err) => Err(err), + }; + self.respond(request_id, result); + } + HostRequest::Terminate { + session_id, + cell_id, + } => { + let result = match self.session(&session_id) { + Ok(session) => session.terminate(cell_id.into()).await.map(|outcome| { + HostResponse::WaitCompleted { + outcome: outcome.into(), + } + }), + Err(err) => Err(err), + }; + self.respond(request_id, result); + } + HostRequest::ShutdownSession { session_id } => { + let session = self + .sessions + .lock() + .unwrap_or_else(PoisonError::into_inner) + .remove(&session_id); + let result = match session { + Some(session) => match session.shutdown().await { + Ok(()) => { + self.peer.wait_for_session_cells(&session_id).await; + Ok(HostResponse::SessionClosed { session_id }) + } + Err(err) => Err(err), + }, + None => Err(format!("unknown code-mode session {session_id}")), + }; + self.respond(request_id, result); + } + } + } + + fn open_session(&self, session_id: SessionId) -> Result<(), String> { + let mut sessions = self.sessions.lock().unwrap_or_else(PoisonError::into_inner); + if sessions.contains_key(&session_id) { + return Err(format!( + "code-mode session ID `{session_id}` is already open" + )); + } + if self.closing.load(Ordering::Acquire) { + return Err("code-mode host is shutting down".to_string()); + } + if !self + .seen_session_ids + .lock() + .unwrap_or_else(PoisonError::into_inner) + .remember(session_id.clone()) + { + return Err(format!("code-mode session ID `{session_id}` was reused")); + } + let delegate = Arc::new(RemoteDelegate::new( + session_id.clone(), + Arc::clone(&self.peer), + )); + let peer = Arc::downgrade(&self.peer); + let task_failure_handler = Arc::new(move |reason| { + if let Some(peer) = peer.upgrade() { + peer.fail(reason); + } + }); + sessions.insert( + session_id, + Arc::new( + InProcessCodeModeSession::with_delegate_and_task_failure_handler( + delegate, + task_failure_handler, + ), + ), + ); + Ok(()) + } + + fn session(&self, session_id: &SessionId) -> Result, String> { + self.sessions + .lock() + .unwrap_or_else(PoisonError::into_inner) + .get(session_id) + .cloned() + .ok_or_else(|| format!("unknown code-mode session {session_id}")) + } + + fn respond(&self, id: RequestId, result: Result) { + self.peer.respond(id, result); + } + + fn cancel_request(&self, request_id: RequestId) { + self.requests + .lock() + .unwrap_or_else(PoisonError::into_inner) + .cancel(request_id); + } + + fn finish_request(&self, request_id: RequestId) { + self.requests + .lock() + .unwrap_or_else(PoisonError::into_inner) + .finish(request_id); + } + + async fn disconnect(&self) { + self.closing.store(true, Ordering::Release); + self.requests + .lock() + .unwrap_or_else(PoisonError::into_inner) + .cancel_all(); + self.request_tasks.close(); + self.request_tasks.wait().await; + let sessions = self + .sessions + .lock() + .unwrap_or_else(PoisonError::into_inner) + .drain() + .map(|(_, session)| session) + .collect::>(); + for session in sessions { + let _ = session.shutdown().await; + } + } +} + +#[derive(Clone, Copy)] +enum RequestKind { + OpenSession, + Execute, + Wait, + Terminate, + ShutdownSession, +} + +impl RequestKind { + fn from(request: &HostRequest) -> Self { + match request { + HostRequest::OpenSession { .. } => Self::OpenSession, + HostRequest::Execute { .. } => Self::Execute, + HostRequest::Wait { .. } => Self::Wait, + HostRequest::Terminate { .. } => Self::Terminate, + HostRequest::ShutdownSession { .. } => Self::ShutdownSession, + } + } + + fn is_cancellable(self) -> bool { + matches!(self, Self::Execute | Self::Wait) + } +} + +struct ActiveRequest { + kind: RequestKind, + cancellation: CancellationToken, +} + +#[derive(Default)] +struct RequestRegistry { + active: HashMap, + recent: HashSet, + recent_order: VecDeque, +} + +impl RequestRegistry { + fn start( + &mut self, + request_id: RequestId, + kind: RequestKind, + ) -> Result { + if self.active.contains_key(&request_id) || self.recent.contains(&request_id) { + anyhow::bail!("duplicate code-mode request ID {request_id:?}"); + } + let cancellation = CancellationToken::new(); + self.active.insert( + request_id, + ActiveRequest { + kind, + cancellation: cancellation.clone(), + }, + ); + Ok(cancellation) + } + + fn cancel(&self, request_id: RequestId) { + if let Some(request) = self.active.get(&request_id) + && request.kind.is_cancellable() + { + request.cancellation.cancel(); + } + } + + fn finish(&mut self, request_id: RequestId) { + if self.active.remove(&request_id).is_none() { + return; + } + self.recent.insert(request_id); + self.recent_order.push_back(request_id); + while self.recent_order.len() > MAX_RECENT_REQUEST_IDS { + if let Some(expired) = self.recent_order.pop_front() { + self.recent.remove(&expired); + } + } + } + + fn cancel_all(&self) { + for request in self.active.values() { + request.cancellation.cancel(); + } + } +} + +#[derive(Default)] +struct SeenSessionIds { + ids: HashSet, + order: VecDeque, +} + +impl SeenSessionIds { + fn remember(&mut self, session_id: SessionId) -> bool { + if !self.ids.insert(session_id.clone()) { + return false; + } + self.order.push_back(session_id); + while self.order.len() > MAX_RECENT_SESSION_IDS { + if let Some(expired) = self.order.pop_front() { + self.ids.remove(&expired); + } + } + true + } +} + +#[cfg(test)] +#[path = "host_tests.rs"] +mod tests; diff --git a/codex-rs/code-mode-host/src/main.rs b/codex-rs/code-mode-host/src/main.rs index f328e4d9d..b215869c8 100644 --- a/codex-rs/code-mode-host/src/main.rs +++ b/codex-rs/code-mode-host/src/main.rs @@ -1 +1,4 @@ -fn main() {} +#[tokio::main(flavor = "current_thread")] +async fn main() -> anyhow::Result<()> { + codex_code_mode_host::run_stdio().await +} diff --git a/codex-rs/code-mode-host/src/peer.rs b/codex-rs/code-mode-host/src/peer.rs new file mode 100644 index 000000000..14ff068e7 --- /dev/null +++ b/codex-rs/code-mode-host/src/peer.rs @@ -0,0 +1,541 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::PoisonError; +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; + +use codex_code_mode_protocol::CellId; +use codex_code_mode_protocol::StartedCell; +use codex_code_mode_protocol::host::DelegateRequest; +use codex_code_mode_protocol::host::DelegateRequestId; +use codex_code_mode_protocol::host::DelegateResponse; +use codex_code_mode_protocol::host::EncodedFrame; +use codex_code_mode_protocol::host::HostToClient; +use codex_code_mode_protocol::host::RequestId; +use codex_code_mode_protocol::host::SessionId; +use codex_code_mode_protocol::host::WireResult; +use tokio::sync::Mutex; +use tokio::sync::Notify; +use tokio::sync::OwnedSemaphorePermit; +use tokio::sync::Semaphore; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio_util::sync::CancellationToken; + +const CELL_MESSAGE_CAPACITY: usize = 128; +const MAX_PENDING_DELEGATE_CALLS: usize = 256; + +pub(super) struct HostPeer { + outgoing_tx: mpsc::Sender, + pending: Mutex>, + delegate_permits: Arc, + cell_routes: StdMutex>, + cell_routes_changed: Notify, + next_request_id: AtomicI64, + disconnected: CancellationToken, + failure: StdMutex>, +} + +struct PendingDelegate { + response_tx: oneshot::Sender>, + dispatched: bool, + _permit: OwnedSemaphorePermit, +} + +enum CellRoute { + Pending(VecDeque), + Active(mpsc::Sender), +} + +enum CellMessage { + Delegate { + id: DelegateRequestId, + request: DelegateRequest, + dispatched_tx: oneshot::Sender>, + }, + Closed, +} + +impl HostPeer { + pub(super) fn new(outgoing_tx: mpsc::Sender) -> Self { + Self { + outgoing_tx, + pending: Mutex::new(HashMap::new()), + delegate_permits: Arc::new(Semaphore::new(MAX_PENDING_DELEGATE_CALLS)), + cell_routes: StdMutex::new(HashMap::new()), + cell_routes_changed: Notify::new(), + next_request_id: AtomicI64::new(1), + disconnected: CancellationToken::new(), + failure: StdMutex::new(None), + } + } + + pub(super) fn send(&self, message: HostToClient) -> Result<(), PeerSendError> { + let frame = EncodedFrame::encode(&message) + .map_err(|err| PeerSendError::Payload(err.to_string()))?; + self.send_frame(frame) + } + + pub(super) fn respond( + &self, + id: RequestId, + result: Result, + ) { + let message = HostToClient::Response { + id, + result: WireResult::from_result(result), + }; + if let Err(PeerSendError::Payload(err)) = self.send(message) { + let _ = self.send(HostToClient::Response { + id, + result: WireResult::Err { + message: format!("code-mode host response exceeds the IPC frame limit: {err}"), + }, + }); + } + } + + fn initial_response( + &self, + id: RequestId, + result: Result, + ) { + let message = HostToClient::InitialResponse { + id, + result: WireResult::from_result(result), + }; + if let Err(PeerSendError::Payload(err)) = self.send(message) { + let _ = self.send(HostToClient::InitialResponse { + id, + result: WireResult::Err { + message: format!( + "code-mode initial response exceeds the IPC frame limit: {err}" + ), + }, + }); + } + } + + pub(super) async fn call( + self: &Arc, + session_id: SessionId, + request: DelegateRequest, + cancellation_token: CancellationToken, + ) -> Result { + if self.disconnected.is_cancelled() { + return Err("code-mode client connection closed".to_string()); + } + let Ok(permit) = Arc::clone(&self.delegate_permits).try_acquire_owned() else { + return Err("code-mode host has too many pending delegate calls".to_string()); + }; + let id = DelegateRequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed)); + let (response_tx, response_rx) = oneshot::channel(); + self.pending.lock().await.insert( + id, + PendingDelegate { + response_tx, + dispatched: false, + _permit: permit, + }, + ); + let mut pending = PendingDelegateRequest::new(Arc::clone(self), id); + let cell_id = match &request { + DelegateRequest::InvokeTool { invocation } => invocation.cell_id.clone().into(), + DelegateRequest::Notify { cell_id, .. } => cell_id.clone().into(), + }; + let (dispatched_tx, dispatched_rx) = oneshot::channel(); + if let Err(err) = self.route_cell_message( + (session_id, cell_id), + CellMessage::Delegate { + id, + request, + dispatched_tx, + }, + ) { + self.pending.lock().await.remove(&id); + pending.disarm(); + return Err(err); + } + + let dispatched = tokio::select! { + dispatched = dispatched_rx => dispatched.map_err(|_| { + "code-mode cell route closed before dispatching delegate request".to_string() + })?, + _ = self.disconnected.cancelled() => { + self.pending.lock().await.remove(&id); + pending.disarm(); + return Err("code-mode client connection closed".to_string()); + } + }; + if let Err(err) = dispatched { + self.pending.lock().await.remove(&id); + pending.disarm(); + return Err(err); + } + + tokio::select! { + response = response_rx => { + pending.disarm(); + response.map_err(|_| { + "code-mode client closed before returning delegate output".to_string() + })? + } + _ = cancellation_token.cancelled() => { + if self.remove_pending(id).await.is_some() { + let _ = self.send(HostToClient::CancelDelegateRequest { id }); + } + pending.disarm(); + Err("code mode delegate request cancelled".to_string()) + } + _ = self.disconnected.cancelled() => { + self.pending.lock().await.remove(&id); + pending.disarm(); + Err("code-mode client connection closed".to_string()) + } + } + } + + pub(super) async fn complete( + &self, + id: DelegateRequestId, + response: Result, + ) { + if let Some(pending) = self.remove_pending(id).await { + let _ = pending.response_tx.send(response); + } + } + + pub(super) fn start_cell( + self: &Arc, + session_id: SessionId, + request_id: RequestId, + started: StartedCell, + active_cell_permit: OwnedSemaphorePermit, + ) -> oneshot::Receiver<()> { + let (initial_response_sent_tx, initial_response_sent_rx) = oneshot::channel(); + let key = (session_id, started.cell_id.clone()); + let (messages_tx, messages_rx) = mpsc::channel(CELL_MESSAGE_CAPACITY); + let previous = self + .cell_routes + .lock() + .unwrap_or_else(PoisonError::into_inner) + .insert(key.clone(), CellRoute::Active(messages_tx.clone())); + match previous { + Some(CellRoute::Pending(messages)) => { + for message in messages { + if messages_tx.try_send(message).is_err() { + self.disconnect(); + return initial_response_sent_rx; + } + } + } + Some(CellRoute::Active(_)) => { + self.disconnect(); + return initial_response_sent_rx; + } + None => {} + } + let peer = Arc::clone(self); + self.spawn_critical("cell forwarding", async move { + drive_cell( + peer, + key, + request_id, + started, + messages_rx, + initial_response_sent_tx, + active_cell_permit, + ) + .await; + }); + initial_response_sent_rx + } + + pub(super) fn close_cell(&self, session_id: SessionId, cell_id: CellId) { + let _ = self.route_cell_message((session_id, cell_id), CellMessage::Closed); + } + + pub(super) fn disconnect(&self) { + self.disconnected.cancel(); + } + + pub(super) fn fail(&self, reason: String) { + let mut failure = self.failure.lock().unwrap_or_else(PoisonError::into_inner); + if failure.is_none() { + *failure = Some(reason); + } + drop(failure); + self.disconnect(); + } + + pub(super) fn failure(&self) -> Option { + self.failure + .lock() + .unwrap_or_else(PoisonError::into_inner) + .clone() + } + + pub(super) fn is_disconnected(&self) -> bool { + self.disconnected.is_cancelled() + } + + pub(super) async fn disconnected(&self) { + self.disconnected.cancelled().await; + } + + pub(super) fn disconnection_token(&self) -> CancellationToken { + self.disconnected.clone() + } + + pub(super) async fn wait_for_session_cells(&self, session_id: &SessionId) { + loop { + let changed = self.cell_routes_changed.notified(); + if !self + .cell_routes + .lock() + .unwrap_or_else(PoisonError::into_inner) + .keys() + .any(|(route_session_id, _)| route_session_id == session_id) + { + return; + } + tokio::select! { + _ = changed => {} + _ = self.disconnected.cancelled() => return, + } + } + } + + async fn send_delegate_if_pending( + &self, + id: DelegateRequestId, + session_id: SessionId, + request: DelegateRequest, + dispatched_tx: oneshot::Sender>, + ) { + let result = { + let mut pending = self.pending.lock().await; + let Some(pending) = pending.get_mut(&id) else { + let _ = dispatched_tx.send(Err( + "code-mode delegate request was cancelled before dispatch".to_string(), + )); + return; + }; + match self.send(HostToClient::DelegateRequest { + id, + session_id, + request, + }) { + Ok(()) => { + pending.dispatched = true; + Ok(()) + } + Err(err) => Err(err.to_string()), + } + }; + let _ = dispatched_tx.send(result); + } + + fn route_cell_message( + &self, + key: (SessionId, CellId), + message: CellMessage, + ) -> Result<(), String> { + use std::collections::hash_map::Entry; + + let result = match self + .cell_routes + .lock() + .unwrap_or_else(PoisonError::into_inner) + .entry(key) + { + Entry::Occupied(mut entry) => match entry.get_mut() { + CellRoute::Pending(messages) if messages.len() < CELL_MESSAGE_CAPACITY => { + messages.push_back(message); + Ok(()) + } + CellRoute::Pending(_) => Err("code-mode cell message queue is full".to_string()), + CellRoute::Active(sender) => sender + .try_send(message) + .map_err(|_| "code-mode cell message queue is unavailable".to_string()), + }, + Entry::Vacant(entry) => { + entry.insert(CellRoute::Pending(VecDeque::from([message]))); + Ok(()) + } + }; + if result.is_err() { + self.disconnect(); + } + result + } + + async fn remove_pending(&self, id: DelegateRequestId) -> Option { + self.pending.lock().await.remove(&id) + } + + pub(super) fn spawn_critical(self: &Arc, task_name: &'static str, future: F) + where + F: std::future::Future + Send + 'static, + { + let task = tokio::spawn(future); + let peer = Arc::clone(self); + tokio::spawn(async move { + if let Err(err) = task.await { + peer.fail(format!("code-mode {task_name} task failed: {err}")); + } + }); + } + + fn send_frame(&self, frame: EncodedFrame) -> Result<(), PeerSendError> { + match self.outgoing_tx.try_send(frame) { + Ok(()) => Ok(()), + Err(mpsc::error::TrySendError::Full(_)) => { + self.disconnect(); + Err(PeerSendError::Unavailable( + "code-mode host outgoing queue is full".to_string(), + )) + } + Err(mpsc::error::TrySendError::Closed(_)) => { + self.disconnect(); + Err(PeerSendError::Unavailable( + "code-mode client connection closed".to_string(), + )) + } + } + } +} + +async fn drive_cell( + peer: Arc, + key: (SessionId, CellId), + request_id: RequestId, + started: StartedCell, + mut messages_rx: mpsc::Receiver, + initial_response_sent_tx: oneshot::Sender<()>, + _active_cell_permit: OwnedSemaphorePermit, +) { + let mut initial_response_sent_tx = Some(initial_response_sent_tx); + let initial_response = started.initial_response(); + tokio::pin!(initial_response); + let closed = loop { + tokio::select! { + biased; + result = &mut initial_response => { + peer.initial_response(request_id, result.map(Into::into)); + if let Some(initial_response_sent_tx) = initial_response_sent_tx.take() { + let _ = initial_response_sent_tx.send(()); + } + break false; + } + message = messages_rx.recv() => match message { + Some(CellMessage::Delegate { + id, + request, + dispatched_tx, + }) => { + peer.send_delegate_if_pending(id, key.0.clone(), request, dispatched_tx).await; + } + Some(CellMessage::Closed) | None => break true, + }, + _ = peer.disconnected.cancelled() => { + peer.remove_cell_route(&key); + return; + } + } + }; + + if closed { + peer.initial_response(request_id, initial_response.await.map(Into::into)); + if let Some(initial_response_sent_tx) = initial_response_sent_tx.take() { + let _ = initial_response_sent_tx.send(()); + } + } else { + loop { + tokio::select! { + message = messages_rx.recv() => match message { + Some(CellMessage::Delegate { + id, + request, + dispatched_tx, + }) => { + peer.send_delegate_if_pending(id, key.0.clone(), request, dispatched_tx).await; + } + Some(CellMessage::Closed) | None => break, + }, + _ = peer.disconnected.cancelled() => { + peer.remove_cell_route(&key); + return; + } + } + } + } + let _ = peer.send(HostToClient::CellClosed { + session_id: key.0.clone(), + cell_id: (&key.1).into(), + }); + peer.remove_cell_route(&key); +} + +impl HostPeer { + fn remove_cell_route(&self, key: &(SessionId, CellId)) { + let removed = self + .cell_routes + .lock() + .unwrap_or_else(PoisonError::into_inner) + .remove(key); + if removed.is_some() { + self.cell_routes_changed.notify_waiters(); + } + } +} + +pub(super) enum PeerSendError { + Payload(String), + Unavailable(String), +} + +impl std::fmt::Display for PeerSendError { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Payload(message) | Self::Unavailable(message) => formatter.write_str(message), + } + } +} + +struct PendingDelegateRequest { + peer: Arc, + id: Option, +} + +impl PendingDelegateRequest { + fn new(peer: Arc, id: DelegateRequestId) -> Self { + Self { peer, id: Some(id) } + } + + fn disarm(&mut self) { + self.id = None; + } +} + +impl Drop for PendingDelegateRequest { + fn drop(&mut self) { + let Some(id) = self.id.take() else { + return; + }; + let peer = Arc::clone(&self.peer); + tokio::spawn(async move { + if let Some(pending) = peer.remove_pending(id).await + && pending.dispatched + { + let _ = peer.send(HostToClient::CancelDelegateRequest { id }); + } + }); + } +} + +#[cfg(test)] +#[path = "peer_tests.rs"] +mod tests; diff --git a/codex-rs/code-mode-host/src/peer_tests.rs b/codex-rs/code-mode-host/src/peer_tests.rs new file mode 100644 index 000000000..a926027ce --- /dev/null +++ b/codex-rs/code-mode-host/src/peer_tests.rs @@ -0,0 +1,95 @@ +use std::sync::Arc; +use std::time::Duration; + +use codex_code_mode_protocol::CellId; +use codex_code_mode_protocol::RuntimeResponse; +use codex_code_mode_protocol::StartedCell; +use codex_code_mode_protocol::host::DelegateRequest; +use codex_code_mode_protocol::host::RequestId; +use codex_code_mode_protocol::host::SessionId; +use pretty_assertions::assert_eq; +use tokio::sync::Semaphore; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::sync::oneshot::error::TryRecvError; +use tokio_util::sync::CancellationToken; + +use super::HostPeer; +use super::MAX_PENDING_DELEGATE_CALLS; + +fn session_id(value: &str) -> SessionId { + SessionId::new(value).expect("session ID") +} + +#[tokio::test] +async fn start_cell_reports_when_initial_response_is_enqueued() { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(/*max_capacity*/ 4); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + let cell_id = CellId::new("cell-1".to_string()); + let (response_tx, response_rx) = oneshot::channel(); + let started = StartedCell::new(cell_id.clone(), response_rx); + let active_cell_permits = Arc::new(Semaphore::new(/*permits*/ 1)); + let active_cell_permit = Arc::clone(&active_cell_permits) + .try_acquire_owned() + .expect("active cell permit"); + + let mut initial_response_sent = peer.start_cell( + session_id("session-1"), + RequestId::new(/*value*/ 1), + started, + active_cell_permit, + ); + assert_eq!(initial_response_sent.try_recv(), Err(TryRecvError::Empty)); + + response_tx + .send(RuntimeResponse::Result { + cell_id: cell_id.clone(), + content_items: Vec::new(), + error_text: None, + }) + .expect("initial response receiver"); + initial_response_sent + .await + .expect("initial response completion"); + outgoing_rx.recv().await.expect("initial response frame"); + assert_eq!(active_cell_permits.available_permits(), 0); + + peer.close_cell(session_id("session-1"), cell_id); + let permit = tokio::time::timeout( + Duration::from_secs(1), + Arc::clone(&active_cell_permits).acquire_owned(), + ) + .await + .expect("cell permit should be released") + .expect("cell permit semaphore should remain open"); + drop(permit); +} + +#[tokio::test] +async fn pending_delegate_limit_rejects_call_without_disconnecting() { + let (outgoing_tx, _outgoing_rx) = mpsc::channel(/*max_capacity*/ 1); + let peer = Arc::new(HostPeer::new(outgoing_tx)); + let permits = Arc::clone(&peer.delegate_permits) + .acquire_many_owned(MAX_PENDING_DELEGATE_CALLS as u32) + .await + .expect("delegate permits"); + + let result = peer + .call( + session_id("session-1"), + DelegateRequest::Notify { + call_id: "call-1".to_string(), + cell_id: CellId::new("cell-1".to_string()).into(), + text: "hello".to_string(), + }, + CancellationToken::new(), + ) + .await; + + assert_eq!( + result, + Err("code-mode host has too many pending delegate calls".to_string()) + ); + assert!(!peer.is_disconnected()); + drop(permits); +} diff --git a/justfile b/justfile index 03f87391a..d4a0e6566 100644 --- a/justfile +++ b/justfile @@ -31,6 +31,10 @@ tui-with-exec-server *args: file-search *args: cargo run --bin codex-file-search -- {args} +# Run the standalone code-mode host from source. +code-mode-host *args: + cargo run --bin codex-code-mode-host -- {args} + # Build the CLI and run the app-server test client app-server-test-client *args: cargo build -p codex-cli @@ -107,6 +111,16 @@ bazel-codex *args: bazel-codex *args: bazel run //codex-rs/cli:codex --run_under='cd /d "{{ invocation_directory_native() }}" &&' -- @($args | Select-Object -Skip 1) +# Build and run the standalone code-mode host from source using Bazel. +[no-cd] +[unix] +bazel-code-mode-host *args: + bazel run //codex-rs/code-mode-host:codex-code-mode-host --run_under="cd $PWD &&" -- "$@" + +[windows] +bazel-code-mode-host *args: + bazel run //codex-rs/code-mode-host:codex-code-mode-host --run_under='cd /d "{{ invocation_directory_native() }}" &&' -- @($args | Select-Object -Skip 1) + [no-cd] bazel-lock-update: bazel mod deps --lockfile_mode=update