From 83af3abc685149ede307cc754057bc597d7e0169 Mon Sep 17 00:00:00 2001 From: starr-openai Date: Tue, 19 May 2026 13:31:57 -0700 Subject: [PATCH] Refactor exec-server websocket pump (#23327) ## Why Exec-server websocket handling had separate reader and writer tasks for the same socket. That made websocket control-frame handling asymmetric: the task reading frames could observe `Ping`, but the task allowed to write frames was elsewhere. This PR moves each physical websocket onto one always-running pump so the socket owner can handle application frames and websocket control frames together. ## What changed - Refactored direct exec-server websocket connections in `connection.rs` to use one task that owns the websocket for outbound JSON-RPC, inbound JSON-RPC, periodic keepalive pings, and `Ping` -> `Pong` replies. - Refactored relay websocket handling in `relay.rs` the same way for both the harness-side logical connection and the multiplexed executor physical socket. - Preserved the existing keepalive ownership policy: outbound direct websocket clients still send periodic pings, inbound Axum accepts only reply with pongs, and relay physical websocket endpoints keep their existing periodic pings. - Added focused websocket pump tests for ping/pong, binary JSON-RPC, relay data, malformed relay text frames, and close/disconnect behavior. - Reconnect behavior is intentionally left for a follow-up. ## Validation - Devbox Bazel focused unit target: - `//codex-rs/exec-server:exec-server-unit-tests --test_filter='websocket_connection_|harness_connection_|multiplexed_executor_'` --- codex-rs/exec-server/src/connection.rs | 468 +++++++++++++------- codex-rs/exec-server/src/relay.rs | 568 ++++++++++++++++--------- 2 files changed, 697 insertions(+), 339 deletions(-) diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index cf504bbea..b211c504a 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -323,37 +323,20 @@ impl JsonRpcConnection { where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (websocket_writer, websocket_reader) = stream.split(); - Self::from_websocket_parts( - websocket_writer, - websocket_reader, - connection_label, - Some(WEBSOCKET_KEEPALIVE_INTERVAL), - ) + Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None) } pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self { - let (websocket_writer, websocket_reader) = stream.split(); - Self::from_websocket_parts( - websocket_writer, - websocket_reader, - connection_label, - // Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients - // own keepalive pings so one side does not accidentally create redundant traffic. - /*keepalive_interval*/ - None, - ) + Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL)) } - fn from_websocket_parts( - mut websocket_writer: W, - mut websocket_reader: R, + fn from_websocket_stream( + mut websocket: T, connection_label: String, - keepalive_interval: Option, + ping_interval: Option, ) -> Self where - W: Sink + Unpin + Send + 'static, - R: Stream> + Unpin + Send + 'static, + T: Sink + Stream> + Unpin + Send + 'static, M: JsonRpcWebSocketMessage, E: std::fmt::Display + Send + 'static, { @@ -361,118 +344,106 @@ impl JsonRpcConnection { let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); let (disconnected_tx, disconnected_rx) = watch::channel(false); - let reader_label = connection_label.clone(); - let incoming_tx_for_reader = incoming_tx.clone(); - let disconnected_tx_for_reader = disconnected_tx.clone(); - let reader_task = tokio::spawn(async move { + let websocket_task = tokio::spawn(async move { + let mut ping_interval = ping_interval.map(|ping_interval| { + let mut interval = tokio::time::interval_at( + tokio::time::Instant::now() + ping_interval, + ping_interval, + ); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + interval + }); + loop { - match websocket_reader.next().await { - Some(Ok(message)) => match message.parse_jsonrpc_frame() { - Ok(JsonRpcWebSocketFrame::Message(message)) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } + tokio::select! { + maybe_message = outgoing_rx.recv() => { + let Some(message) = maybe_message else { + break; + }; + if let Err(reason) = send_websocket_jsonrpc_message( + &mut websocket, + &connection_label, + &message, + ) + .await + { + send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await; + break; } - Err(err) => { - send_malformed_message( - &incoming_tx_for_reader, - Some(format!( - "failed to parse websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; + } + _ = async { + match ping_interval.as_mut() { + Some(interval) => interval.tick().await, + None => std::future::pending().await, } - Ok(JsonRpcWebSocketFrame::Close) => { + } => { + if let Err(err) = websocket.send(M::ping()).await { send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, + &incoming_tx, + &disconnected_tx, + Some(format!( + "failed to write websocket ping to {connection_label}: {err}" + )), ) .await; break; } - Ok(JsonRpcWebSocketFrame::Ignore) => {} - }, - Some(Err(err)) => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - Some(format!( - "failed to read websocket JSON-RPC message from {reader_label}: {err}" - )), - ) - .await; - break; } - None => { - send_disconnected( - &incoming_tx_for_reader, - &disconnected_tx_for_reader, - /*reason*/ None, - ) - .await; - break; - } - } - } - }); - - let writer_task = tokio::spawn(async move { - if let Some(keepalive_interval) = keepalive_interval { - let mut keepalive = tokio::time::interval_at( - tokio::time::Instant::now() + keepalive_interval, - keepalive_interval, - ); - keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - tokio::select! { - maybe_message = outgoing_rx.recv() => { - let Some(message) = maybe_message else { - break; - }; - if let Err(reason) = send_websocket_jsonrpc_message( - &mut websocket_writer, - &connection_label, - &message, - ) - .await - { - send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await; - break; - } - } - _ = keepalive.tick() => { - if let Err(err) = websocket_writer.send(M::ping()).await { + incoming_message = websocket.next() => { + match incoming_message { + Some(Ok(message)) => match message.parse_jsonrpc_frame() { + Ok(JsonRpcWebSocketFrame::Message(message)) => { + if incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Ok(JsonRpcWebSocketFrame::Close) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + /*reason*/ None, + ) + .await; + break; + } + Ok(JsonRpcWebSocketFrame::Ignore) => {} + Err(err) => { + send_malformed_message( + &incoming_tx, + Some(format!( + "failed to parse websocket JSON-RPC message from {connection_label}: {err}" + )), + ) + .await; + } + }, + Some(Err(err)) => { send_disconnected( &incoming_tx, &disconnected_tx, Some(format!( - "failed to write websocket ping to {connection_label}: {err}" + "failed to read websocket JSON-RPC message from {connection_label}: {err}" )), ) .await; break; } + None => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + /*reason*/ None, + ) + .await; + break; + } } } } - } else { - while let Some(message) = outgoing_rx.recv().await { - if let Err(reason) = send_websocket_jsonrpc_message( - &mut websocket_writer, - &connection_label, - &message, - ) - .await - { - send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await; - break; - } - } } }); @@ -480,7 +451,7 @@ impl JsonRpcConnection { outgoing_tx, incoming_rx, disconnected_rx, - task_handles: vec![reader_task, writer_task], + task_handles: vec![websocket_task], transport: JsonRpcTransport::Plain, } } @@ -619,34 +590,250 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result, + #[tokio::test] + async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let connection = JsonRpcConnection::from_websocket_stream( + client_websocket, + "test".into(), + Some(WEBSOCKET_KEEPALIVE_INTERVAL), + ); + + let message = timeout(Duration::from_secs(1), server_websocket.next()) + .await? + .expect("websocket should stay open")?; + assert!(matches!(message, Message::Ping(_))); + + drop(connection); + Ok(()) } - impl Sink for TestWebSocketSink { + #[tokio::test] + async fn websocket_connection_ignores_server_pong() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into()); + + server_websocket + .send(Message::Pong(b"check".to_vec().into())) + .await?; + assert!( + timeout(Duration::from_millis(50), connection.incoming_rx.recv()) + .await + .is_err() + ); + + drop(connection); + Ok(()) + } + + #[tokio::test] + async fn websocket_connection_reports_server_close() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into()); + + server_websocket.close(None).await?; + assert!(matches!( + timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?, + Some(JsonRpcConnectionEvent::Disconnected { reason: None }) + )); + + drop(connection); + Ok(()) + } + + #[tokio::test] + async fn websocket_connection_accepts_binary_jsonrpc_message() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into()); + let message = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "test".to_string(), + params: None, + trace: None, + }); + + server_websocket + .send(Message::Binary(serde_json::to_vec(&message)?.into())) + .await?; + assert!(matches!( + timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?, + Some(JsonRpcConnectionEvent::Message(actual)) if actual == message + )); + + drop(connection); + Ok(()) + } + + #[tokio::test] + async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured() + -> anyhow::Result<()> { + let (websocket, control, mut outbound_rx) = + ControlledWebSocket::new(/*write_ready*/ false); + let mut connection = JsonRpcConnection::from_websocket_stream( + websocket, + "test".into(), + /*ping_interval*/ None, + ); + let message = test_jsonrpc_message(); + + connection.outgoing_tx.send(message.clone()).await?; + control.wait_for_blocked_write().await?; + control.send_inbound(Message::Pong(b"check".to_vec().into()))?; + assert!( + timeout(Duration::from_millis(50), connection.incoming_rx.recv()) + .await + .is_err() + ); + + control.set_write_ready(); + assert!(matches!( + timeout(Duration::from_secs(1), outbound_rx.next()).await?, + Some(Message::Text(text)) if serde_json::from_str::(&text)? == message + )); + drop(connection); + Ok(()) + } + + async fn websocket_pair() -> anyhow::Result<( + WebSocketStream>, + WebSocketStream, + )> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let websocket_url = format!("ws://{}", listener.local_addr()?); + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + accept_async(stream).await.map_err(anyhow::Error::from) + }); + let (client_websocket, _) = connect_async(websocket_url).await?; + let server_websocket = server_task.await??; + Ok((client_websocket, server_websocket)) + } + + fn test_jsonrpc_message() -> JSONRPCMessage { + JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "test".to_string(), + params: None, + trace: None, + }) + } + + struct ControlledWebSocket { + inbound_rx: futures_mpsc::UnboundedReceiver>, + outbound_tx: futures_mpsc::UnboundedSender, + write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, + write_waker: Arc, + } + + struct ControlledWebSocketHandle { + inbound_tx: futures_mpsc::UnboundedSender>, + write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, + write_waker: Arc, + } + + impl ControlledWebSocket { + fn new( + write_ready: bool, + ) -> ( + Self, + ControlledWebSocketHandle, + futures_mpsc::UnboundedReceiver, + ) { + let (inbound_tx, inbound_rx) = futures_mpsc::unbounded(); + let (outbound_tx, outbound_rx) = futures_mpsc::unbounded(); + let write_ready = Arc::new(AtomicBool::new(write_ready)); + let write_blocked = Arc::new(AtomicBool::new(false)); + let write_blocked_waker = Arc::new(AtomicWaker::new()); + let write_waker = Arc::new(AtomicWaker::new()); + ( + Self { + inbound_rx, + outbound_tx, + write_ready: Arc::clone(&write_ready), + write_blocked: Arc::clone(&write_blocked), + write_blocked_waker: Arc::clone(&write_blocked_waker), + write_waker: Arc::clone(&write_waker), + }, + ControlledWebSocketHandle { + inbound_tx, + write_ready, + write_blocked, + write_blocked_waker, + write_waker, + }, + outbound_rx, + ) + } + } + + impl ControlledWebSocketHandle { + fn send_inbound(&self, message: Message) -> anyhow::Result<()> { + self.inbound_tx + .unbounded_send(Ok(message)) + .map_err(anyhow::Error::from) + } + + fn set_write_ready(&self) { + self.write_ready.store(true, Ordering::Release); + self.write_waker.wake(); + } + + async fn wait_for_blocked_write(&self) -> anyhow::Result<()> { + timeout( + Duration::from_secs(1), + futures::future::poll_fn(|cx| { + if self.write_blocked.load(Ordering::Acquire) { + Poll::Ready(()) + } else { + self.write_blocked_waker.register(cx.waker()); + Poll::Pending + } + }), + ) + .await?; + Ok(()) + } + } + + impl Sink for ControlledWebSocket { type Error = std::convert::Infallible; - fn poll_ready( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.write_ready.load(Ordering::Acquire) { + Poll::Ready(Ok(())) + } else { + self.write_blocked.store(true, Ordering::Release); + self.write_blocked_waker.wake(); + self.write_waker.register(cx.waker()); + Poll::Pending + } } fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - self.get_mut() - .message_tx + self.outbound_tx .unbounded_send(item) - .expect("test websocket receiver should stay open"); + .expect("test outbound receiver should stay open"); Ok(()) } @@ -665,24 +852,11 @@ mod tests { } } - #[tokio::test] - async fn websocket_connection_sends_keepalive_ping() { - let (message_tx, mut message_rx) = futures_mpsc::unbounded::(); - let websocket_writer = TestWebSocketSink { message_tx }; - let websocket_reader = stream::pending::>(); - let connection = JsonRpcConnection::from_websocket_parts( - websocket_writer, - websocket_reader, - "test".into(), - Some(WEBSOCKET_KEEPALIVE_INTERVAL), - ); + impl Stream for ControlledWebSocket { + type Item = Result; - let message = timeout(Duration::from_secs(1), message_rx.next()) - .await - .expect("keepalive ping should arrive before timeout") - .expect("keepalive ping should be sent"); - assert!(matches!(message, Message::Ping(_))); - - drop(connection); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inbound_rx).poll_next(cx) + } } } diff --git a/codex-rs/exec-server/src/relay.rs b/codex-rs/exec-server/src/relay.rs index 7470a6290..b3708f1d8 100644 --- a/codex-rs/exec-server/src/relay.rs +++ b/codex-rs/exec-server/src/relay.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use codex_app_server_protocol::JSONRPCMessage; +use futures::Sink; use futures::SinkExt; +use futures::Stream; use futures::StreamExt; use prost::Message as ProstMessage; use tokio::io::AsyncRead; @@ -19,7 +21,6 @@ use crate::connection::CHANNEL_CAPACITY; use crate::connection::JsonRpcConnection; use crate::connection::JsonRpcConnectionEvent; use crate::connection::JsonRpcTransport; -use crate::connection::WEBSOCKET_KEEPALIVE_INTERVAL; use crate::relay_proto::RelayData; use crate::relay_proto::RelayMessageFrame; use crate::relay_proto::RelayResume; @@ -140,121 +141,25 @@ fn jsonrpc_payload(message: &JSONRPCMessage) -> Result, ExecServerError> serde_json::to_vec(message).map_err(ExecServerError::Json) } -pub(crate) fn harness_connection_from_websocket( - stream: WebSocketStream, +pub(crate) fn harness_connection_from_websocket( + stream: T, connection_label: String, ) -> JsonRpcConnection where - S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: Sink + Stream> + Unpin + Send + 'static, + E: std::fmt::Display + Send + 'static, { let stream_id = Uuid::new_v4().to_string(); - let (mut websocket_writer, mut websocket_reader) = stream.split(); let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); let (disconnected_tx, disconnected_rx) = watch::channel(false); - let reader_label = connection_label; - let reader_stream_id = stream_id.clone(); - let incoming_tx_for_reader = incoming_tx; - let disconnected_tx_for_reader = disconnected_tx.clone(); - let reader_task = tokio::spawn(async move { - loop { - match websocket_reader.next().await { - Some(Ok(Message::Binary(payload))) => { - let frame = match decode_relay_message_frame(payload.as_ref()) { - Ok(frame) => frame, - Err(err) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: format!( - "failed to parse relay message frame from {reader_label}: {err}" - ), - }) - .await; - continue; - } - }; - if frame.stream_id != reader_stream_id { - continue; - } - let kind = match frame.validate() { - Ok(kind) => kind, - Err(err) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: err.to_string(), - }) - .await; - continue; - } - }; - match kind { - RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() { - Ok(message) => { - if incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Message(message)) - .await - .is_err() - { - break; - } - } - Err(err) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: err.to_string(), - }) - .await; - } - }, - RelayFrameBodyKind::Reset => { - let _ = disconnected_tx_for_reader.send(true); - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Disconnected { - reason: frame.into_reset_reason(), - }) - .await; - break; - } - RelayFrameBodyKind::Ack - | RelayFrameBodyKind::Resume - | RelayFrameBodyKind::Heartbeat => {} - } - } - Some(Ok(Message::Close(_))) | None => { - let _ = disconnected_tx_for_reader.send(true); - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Disconnected { reason: None }) - .await; - break; - } - Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {} - Some(Ok(Message::Text(_))) => { - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::MalformedMessage { - reason: "relay exec-server transport expects binary protobuf frames" - .to_string(), - }) - .await; - } - Some(Err(err)) => { - let _ = disconnected_tx_for_reader.send(true); - let _ = incoming_tx_for_reader - .send(JsonRpcConnectionEvent::Disconnected { - reason: Some(format!( - "failed to read relay websocket frame from {reader_label}: {err}" - )), - }) - .await; - break; - } - } - } - }); - - let writer_task = tokio::spawn(async move { + let websocket_task = tokio::spawn(async move { + let mut websocket = stream; + let reader_label = connection_label; + let reader_stream_id = stream_id.clone(); let resume = RelayMessageFrame::resume(stream_id.clone()); - if websocket_writer + if websocket .send(Message::Binary(encode_relay_message_frame(&resume).into())) .await .is_err() @@ -263,11 +168,6 @@ where return; } - let mut keepalive = tokio::time::interval_at( - tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL, - WEBSOCKET_KEEPALIVE_INTERVAL, - ); - keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut next_seq = 0u32; loop { tokio::select! { @@ -284,7 +184,7 @@ where }; let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload); next_seq = next_seq.wrapping_add(1); - if websocket_writer + if websocket .send(Message::Binary(encode_relay_message_frame(&frame).into())) .await .is_err() @@ -293,10 +193,96 @@ where break; } } - _ = keepalive.tick() => { - if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() { - let _ = disconnected_tx.send(true); - break; + incoming_message = websocket.next() => { + match incoming_message { + Some(Ok(Message::Binary(payload))) => { + let frame = match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(err) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: format!( + "failed to parse relay message frame from {reader_label}: {err}" + ), + }) + .await; + continue; + } + }; + if frame.stream_id != reader_stream_id { + continue; + } + let kind = match frame.validate() { + Ok(kind) => kind, + Err(err) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + continue; + } + }; + match kind { + RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() { + Ok(message) => { + if incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .is_err() + { + break; + } + } + Err(err) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: err.to_string(), + }) + .await; + } + }, + RelayFrameBodyKind::Reset => { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { + reason: frame.into_reset_reason(), + }) + .await; + break; + } + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat => {} + } + } + Some(Ok(Message::Close(_))) | None => { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason: None }) + .await; + break; + } + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {} + Some(Ok(Message::Text(_))) => { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { + reason: "relay exec-server transport expects binary protobuf frames" + .to_string(), + }) + .await; + } + Some(Err(err)) => { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { + reason: Some(format!( + "failed to read relay websocket frame from {reader_label}: {err}" + )), + }) + .await; + break; + } } } } @@ -307,7 +293,7 @@ where outgoing_tx, incoming_rx, disconnected_rx, - task_handles: vec![reader_task, writer_task], + task_handles: vec![websocket_task], transport: JsonRpcTransport::Plain, } } @@ -318,59 +304,42 @@ pub(crate) async fn run_multiplexed_executor( ) where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let (mut websocket_writer, mut websocket_reader) = stream.split(); + let mut websocket = stream; let (physical_outgoing_tx, mut physical_outgoing_rx) = mpsc::channel::>(CHANNEL_CAPACITY); - let writer_task = tokio::spawn(async move { - let mut keepalive = tokio::time::interval_at( - tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL, - WEBSOCKET_KEEPALIVE_INTERVAL, - ); - keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - tokio::select! { - maybe_encoded = physical_outgoing_rx.recv() => { - let Some(encoded) = maybe_encoded else { - break; - }; - if websocket_writer - .send(Message::Binary(encoded.into())) - .await - .is_err() - { - break; - } - } - _ = keepalive.tick() => { - if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() { - break; - } - } - } - } - }); let mut streams: HashMap = HashMap::new(); loop { - let frame = match websocket_reader.next().await { - Some(Ok(Message::Binary(payload))) => { - match decode_relay_message_frame(payload.as_ref()) { - Ok(frame) => frame, - Err(err) => { - warn!("dropping malformed relay message frame from harness: {err}"); - continue; - } + let frame = tokio::select! { + maybe_encoded = physical_outgoing_rx.recv() => { + let Some(encoded) = maybe_encoded else { + break; + }; + if websocket.send(Message::Binary(encoded.into())).await.is_err() { + break; } - } - Some(Ok(Message::Close(_))) | None => break, - Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue, - Some(Ok(Message::Text(_))) => { - warn!("dropping non-binary relay message frame from harness"); continue; } - Some(Err(err)) => { - debug!("multiplexed executor websocket read failed: {err}"); - break; + incoming_message = websocket.next() => match incoming_message { + Some(Ok(Message::Binary(payload))) => { + match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(err) => { + warn!("dropping malformed relay message frame from harness: {err}"); + continue; + } + } + } + Some(Ok(Message::Close(_))) | None => break, + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue, + Some(Ok(Message::Text(_))) => { + warn!("dropping non-binary relay message frame from harness"); + continue; + } + Some(Err(err)) => { + debug!("multiplexed executor websocket read failed: {err}"); + break; + } } }; @@ -423,7 +392,6 @@ pub(crate) async fn run_multiplexed_executor( stream.disconnect(/*reason*/ None).await; } drop(physical_outgoing_tx); - let _ = writer_task.await; } struct VirtualStream { @@ -492,8 +460,20 @@ fn spawn_virtual_stream( #[cfg(test)] mod tests { + use std::pin::Pin; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering; + use std::task::Context; + use std::task::Poll; use std::time::Duration; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::RequestId; + use futures::Sink; + use futures::Stream; + use futures::channel::mpsc as futures_mpsc; + use futures::task::AtomicWaker; use tokio::net::TcpListener; use tokio::time::timeout; use tokio_tungstenite::accept_async; @@ -502,40 +482,107 @@ mod tests { use super::*; - fn test_runtime_paths() -> anyhow::Result { - crate::ExecServerRuntimePaths::new( - std::env::current_exe()?, - /*codex_linux_sandbox_exe*/ None, - ) - .map_err(anyhow::Error::from) - } - #[tokio::test] - async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> { + async fn harness_connection_receives_relay_data() -> anyhow::Result<()> { let (client_websocket, mut server_websocket) = websocket_pair().await?; - let executor_task = tokio::spawn(run_multiplexed_executor( - client_websocket, - ConnectionProcessor::new(test_runtime_paths()?), + let mut connection = + harness_connection_from_websocket(client_websocket, "test".to_string()); + let stream_id = read_resume_stream_id(&mut server_websocket).await?; + let message = test_jsonrpc_message(); + + server_websocket + .send(Message::Binary( + encode_relay_message_frame(&RelayMessageFrame::data( + stream_id, + /*seq*/ 0, + jsonrpc_payload(&message)?, + )) + .into(), + )) + .await?; + assert!(matches!( + timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?, + Some(JsonRpcConnectionEvent::Message(actual)) if actual == message )); - read_keepalive_ping(&mut server_websocket).await?; - - executor_task.abort(); - let _ = executor_task.await; + drop(connection); Ok(()) } #[tokio::test] - async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> { + async fn harness_connection_reports_text_frames_as_malformed() -> anyhow::Result<()> { let (client_websocket, mut server_websocket) = websocket_pair().await?; - let connection = harness_connection_from_websocket(client_websocket, "test".to_string()); + let mut connection = + harness_connection_from_websocket(client_websocket, "test".to_string()); - read_keepalive_ping(&mut server_websocket).await?; + read_resume_stream_id(&mut server_websocket).await?; + server_websocket.send(Message::Text("nope".into())).await?; + assert!(matches!( + timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?, + Some(JsonRpcConnectionEvent::MalformedMessage { reason }) + if reason == "relay exec-server transport expects binary protobuf frames" + )); drop(connection); Ok(()) } + #[tokio::test] + async fn harness_connection_reports_server_close() -> anyhow::Result<()> { + let (client_websocket, mut server_websocket) = websocket_pair().await?; + let mut connection = + harness_connection_from_websocket(client_websocket, "test".to_string()); + + read_resume_stream_id(&mut server_websocket).await?; + server_websocket.close(None).await?; + assert!(matches!( + timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?, + Some(JsonRpcConnectionEvent::Disconnected { reason: None }) + )); + + drop(connection); + Ok(()) + } + + #[tokio::test] + async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured() + -> anyhow::Result<()> { + let (websocket, control, mut outbound_rx) = + ControlledWebSocket::new(/*write_ready*/ true); + let mut connection = harness_connection_from_websocket(websocket, "test".to_string()); + let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next()) + .await? + .expect("resume frame") + else { + anyhow::bail!("expected relay resume frame"); + }; + let stream_id = decode_relay_message_frame(resume_payload.as_ref())?.stream_id; + let message = test_jsonrpc_message(); + + control.set_write_blocked(); + connection.outgoing_tx.send(message.clone()).await?; + control.wait_for_blocked_write().await?; + control.send_inbound(Message::Pong(b"check".to_vec().into()))?; + assert!( + timeout(Duration::from_millis(50), connection.incoming_rx.recv()) + .await + .is_err() + ); + + control.set_write_ready(); + let Message::Binary(data_payload) = timeout(Duration::from_secs(1), outbound_rx.next()) + .await? + .expect("data frame") + else { + anyhow::bail!("expected relay data frame"); + }; + let frame = decode_relay_message_frame(data_payload.as_ref())?; + assert_eq!(frame.stream_id, stream_id); + assert_eq!(frame.into_jsonrpc_message()?, message); + drop(connection); + Ok(()) + } + async fn websocket_pair() -> anyhow::Result<( WebSocketStream>, WebSocketStream, @@ -551,18 +598,155 @@ mod tests { Ok((client_websocket, server_websocket)) } - async fn read_keepalive_ping( + async fn read_resume_stream_id( websocket: &mut WebSocketStream, - ) -> anyhow::Result<()> { - loop { - let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else { - anyhow::bail!("websocket closed before keepalive ping"); - }; - match message? { - Message::Ping(_) => return Ok(()), - Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {} - Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"), + ) -> anyhow::Result { + let message = timeout(Duration::from_secs(1), websocket.next()) + .await? + .expect("websocket should stay open")?; + let Message::Binary(payload) = message else { + anyhow::bail!("expected relay resume frame, got {message:?}"); + }; + let frame = decode_relay_message_frame(payload.as_ref())?; + assert_eq!(frame.validate()?, RelayFrameBodyKind::Resume); + Ok(frame.stream_id) + } + + fn test_jsonrpc_message() -> JSONRPCMessage { + JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "test".to_string(), + params: None, + trace: None, + }) + } + + struct ControlledWebSocket { + inbound_rx: futures_mpsc::UnboundedReceiver>, + outbound_tx: futures_mpsc::UnboundedSender, + write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, + write_waker: Arc, + } + + struct ControlledWebSocketHandle { + inbound_tx: futures_mpsc::UnboundedSender>, + write_ready: Arc, + write_blocked: Arc, + write_blocked_waker: Arc, + write_waker: Arc, + } + + impl ControlledWebSocket { + fn new( + write_ready: bool, + ) -> ( + Self, + ControlledWebSocketHandle, + futures_mpsc::UnboundedReceiver, + ) { + let (inbound_tx, inbound_rx) = futures_mpsc::unbounded(); + let (outbound_tx, outbound_rx) = futures_mpsc::unbounded(); + let write_ready = Arc::new(AtomicBool::new(write_ready)); + let write_blocked = Arc::new(AtomicBool::new(false)); + let write_blocked_waker = Arc::new(AtomicWaker::new()); + let write_waker = Arc::new(AtomicWaker::new()); + ( + Self { + inbound_rx, + outbound_tx, + write_ready: Arc::clone(&write_ready), + write_blocked: Arc::clone(&write_blocked), + write_blocked_waker: Arc::clone(&write_blocked_waker), + write_waker: Arc::clone(&write_waker), + }, + ControlledWebSocketHandle { + inbound_tx, + write_ready, + write_blocked, + write_blocked_waker, + write_waker, + }, + outbound_rx, + ) + } + } + + impl ControlledWebSocketHandle { + fn send_inbound(&self, message: Message) -> anyhow::Result<()> { + self.inbound_tx + .unbounded_send(Ok(message)) + .map_err(anyhow::Error::from) + } + + fn set_write_blocked(&self) { + self.write_ready.store(false, Ordering::Release); + } + + fn set_write_ready(&self) { + self.write_ready.store(true, Ordering::Release); + self.write_waker.wake(); + } + + async fn wait_for_blocked_write(&self) -> anyhow::Result<()> { + timeout( + Duration::from_secs(1), + futures::future::poll_fn(|cx| { + if self.write_blocked.load(Ordering::Acquire) { + Poll::Ready(()) + } else { + self.write_blocked_waker.register(cx.waker()); + Poll::Pending + } + }), + ) + .await?; + Ok(()) + } + } + + impl Sink for ControlledWebSocket { + type Error = std::convert::Infallible; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.write_ready.load(Ordering::Acquire) { + Poll::Ready(Ok(())) + } else { + self.write_blocked.store(true, Ordering::Release); + self.write_blocked_waker.wake(); + self.write_waker.register(cx.waker()); + Poll::Pending } } + + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + self.outbound_tx + .unbounded_send(item) + .expect("test outbound receiver should stay open"); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Stream for ControlledWebSocket { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inbound_rx).poll_next(cx) + } } }