diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ab65fc0a4..4c22d8d73 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1409,6 +1409,7 @@ version = "0.0.0" dependencies = [ "anyhow", "assert_matches", + "async-channel", "async-trait", "base64 0.22.1", "bytes", @@ -1842,6 +1843,7 @@ dependencies = [ name = "codex-code-mode" version = "0.0.0" dependencies = [ + "async-channel", "async-trait", "codex-protocol", "deno_core_icudata", diff --git a/codex-rs/code-mode/Cargo.toml b/codex-rs/code-mode/Cargo.toml index d2f42359d..23b2ce230 100644 --- a/codex-rs/code-mode/Cargo.toml +++ b/codex-rs/code-mode/Cargo.toml @@ -13,6 +13,7 @@ path = "src/lib.rs" workspace = true [dependencies] +async-channel = { workspace = true } async-trait = { workspace = true } codex-protocol = { workspace = true } deno_core_icudata = { workspace = true } diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index 23ca7a746..79ca010c1 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -44,8 +44,8 @@ struct SessionHandle { struct Inner { stored_values: Mutex>, sessions: Mutex>, - turn_message_tx: mpsc::UnboundedSender, - turn_message_rx: Arc>>, + turn_message_tx: async_channel::Sender, + turn_message_rx: async_channel::Receiver, next_cell_id: AtomicU64, } @@ -55,14 +55,14 @@ pub struct CodeModeService { impl CodeModeService { pub fn new() -> Self { - let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel(); + let (turn_message_tx, turn_message_rx) = async_channel::unbounded(); Self { inner: Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), sessions: Mutex::new(HashMap::new()), turn_message_tx, - turn_message_rx: Arc::new(Mutex::new(turn_message_rx)), + turn_message_rx, next_cell_id: AtomicU64::new(1), }), } @@ -146,16 +146,13 @@ impl CodeModeService { pub fn start_turn_worker(&self, host: Arc) -> CodeModeTurnWorker { let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); let inner = Arc::clone(&self.inner); - let turn_message_rx = Arc::clone(&self.inner.turn_message_rx); + let turn_message_rx = self.inner.turn_message_rx.clone(); tokio::spawn(async move { loop { let next_message = tokio::select! { _ = &mut shutdown_rx => break, - message = async { - let mut turn_message_rx = turn_message_rx.lock().await; - turn_message_rx.recv().await - } => message, + message = turn_message_rx.recv() => message.ok(), }; let Some(next_message) = next_message else { break; @@ -361,7 +358,7 @@ async fn run_session_control( cell_id: cell_id.clone(), call_id, text, - }); + }).await; } RuntimeEvent::ToolCall { id, name, input } => { let _ = inner.turn_message_tx.send(TurnMessage::ToolCall { @@ -369,7 +366,7 @@ async fn run_session_control( id, name, input, - }); + }).await; } RuntimeEvent::Result { stored_values, @@ -500,12 +497,12 @@ mod tests { } fn test_inner() -> Arc { - let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel(); + let (turn_message_tx, turn_message_rx) = async_channel::unbounded(); Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), sessions: Mutex::new(HashMap::new()), turn_message_tx, - turn_message_rx: Arc::new(Mutex::new(turn_message_rx)), + turn_message_rx, next_cell_id: AtomicU64::new(1), }) } diff --git a/codex-rs/codex-api/Cargo.toml b/codex-rs/codex-api/Cargo.toml index 862d262b9..14340af1e 100644 --- a/codex-rs/codex-api/Cargo.toml +++ b/codex-rs/codex-api/Cargo.toml @@ -5,6 +5,7 @@ edition.workspace = true license.workspace = true [dependencies] +async-channel = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } bytes = { workspace = true } diff --git a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs index 29de90d40..656b00637 100644 --- a/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs +++ b/codex-rs/codex-api/src/endpoint/realtime_websocket/methods.rs @@ -65,9 +65,9 @@ enum WsCommand { impl WsStream { fn new( inner: WebSocketStream>, - ) -> (Self, mpsc::UnboundedReceiver>) { + ) -> (Self, async_channel::Receiver>) { let (tx_command, mut rx_command) = mpsc::channel::(32); - let (tx_message, rx_message) = mpsc::unbounded_channel::>(); + let (tx_message, rx_message) = async_channel::unbounded::>(); let pump_task = tokio::spawn(async move { let mut inner = inner; @@ -110,7 +110,7 @@ impl WsStream { trace!(payload_len = payload.len(), "realtime websocket received ping"); if let Err(err) = inner.send(Message::Pong(payload)).await { error!("realtime websocket failed to send pong: {err}"); - let _ = tx_message.send(Err(err)); + let _ = tx_message.send(Err(err)).await; break; } } @@ -138,7 +138,7 @@ impl WsStream { } Message::Ping(_) | Message::Pong(_) => {} } - if tx_message.send(Ok(message)).is_err() { + if tx_message.send(Ok(message)).await.is_err() { break; } if is_close { @@ -147,7 +147,7 @@ impl WsStream { } Err(err) => { error!("realtime websocket receive failed: {err}"); - let _ = tx_message.send(Err(err)); + let _ = tx_message.send(Err(err)).await; break; } } @@ -208,7 +208,7 @@ pub struct RealtimeWebsocketWriter { #[derive(Clone)] pub struct RealtimeWebsocketEvents { - rx_message: Arc>>>, + rx_message: async_channel::Receiver>, active_transcript: Arc>, event_parser: RealtimeEventParser, is_closed: Arc, @@ -256,7 +256,7 @@ impl RealtimeWebsocketConnection { fn new( stream: WsStream, - rx_message: mpsc::UnboundedReceiver>, + rx_message: async_channel::Receiver>, event_parser: RealtimeEventParser, ) -> Self { let stream = Arc::new(stream); @@ -268,7 +268,7 @@ impl RealtimeWebsocketConnection { event_parser, }, events: RealtimeWebsocketEvents { - rx_message: Arc::new(Mutex::new(rx_message)), + rx_message, active_transcript: Arc::new(Mutex::new(ActiveTranscriptState::default())), event_parser, is_closed, @@ -369,16 +369,16 @@ impl RealtimeWebsocketEvents { } loop { - let msg = match self.rx_message.lock().await.recv().await { - Some(Ok(msg)) => msg, - Some(Err(err)) => { + let msg = match self.rx_message.recv().await { + Ok(Ok(msg)) => msg, + Ok(Err(err)) => { self.is_closed.store(true, Ordering::SeqCst); error!("realtime websocket read failed: {err}"); return Err(ApiError::Stream(format!( "failed to read websocket message: {err}" ))); } - None => { + Err(_) => { self.is_closed.store(true, Ordering::SeqCst); info!("realtime websocket event stream ended"); return Ok(None);