Refactor exec-server websocket pump (#23327)

## Why
Exec-server websocket handling had separate reader and writer tasks for
the same socket. That made websocket control-frame handling asymmetric:
the task reading frames could observe `Ping`, but the task allowed to
write frames was elsewhere. This PR moves each physical websocket onto
one always-running pump so the socket owner can handle application
frames and websocket control frames together.

## What changed
- Refactored direct exec-server websocket connections in `connection.rs`
to use one task that owns the websocket for outbound JSON-RPC, inbound
JSON-RPC, periodic keepalive pings, and `Ping` -> `Pong` replies.
- Refactored relay websocket handling in `relay.rs` the same way for
both the harness-side logical connection and the multiplexed executor
physical socket.
- Preserved the existing keepalive ownership policy: outbound direct
websocket clients still send periodic pings, inbound Axum accepts only
reply with pongs, and relay physical websocket endpoints keep their
existing periodic pings.
- Added focused websocket pump tests for ping/pong, binary JSON-RPC,
relay data, malformed relay text frames, and close/disconnect behavior.
- Reconnect behavior is intentionally left for a follow-up.

## Validation
- Devbox Bazel focused unit target:
- `//codex-rs/exec-server:exec-server-unit-tests
--test_filter='websocket_connection_|harness_connection_|multiplexed_executor_'`
This commit is contained in:
starr-openai
2026-05-19 13:31:57 -07:00
committed by GitHub
Unverified
parent 5c43a64e2b
commit 83af3abc68
2 changed files with 697 additions and 339 deletions
+321 -147
View File
@@ -323,37 +323,20 @@ impl JsonRpcConnection {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (websocket_writer, websocket_reader) = stream.split();
Self::from_websocket_parts(
websocket_writer,
websocket_reader,
connection_label,
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
)
// Outbound websocket clients own the periodic keepalive pings.
Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL))
}
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
let (websocket_writer, websocket_reader) = stream.split();
Self::from_websocket_parts(
websocket_writer,
websocket_reader,
connection_label,
// Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients
// own keepalive pings so one side does not accidentally create redundant traffic.
/*keepalive_interval*/
None,
)
// Inbound Axum accepts do not send keepalive pings; the outbound client side owns them,
// so the two ends do not generate redundant ping traffic.
Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None)
}
fn from_websocket_parts<W, R, M, E>(
mut websocket_writer: W,
mut websocket_reader: R,
fn from_websocket_stream<T, M, E>(
mut websocket: T,
connection_label: String,
keepalive_interval: Option<Duration>,
ping_interval: Option<Duration>,
) -> Self
where
W: Sink<M, Error = E> + Unpin + Send + 'static,
R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
T: Sink<M, Error = E> + Stream<Item = Result<M, E>> + Unpin + Send + 'static,
M: JsonRpcWebSocketMessage,
E: std::fmt::Display + Send + 'static,
{
@@ -361,118 +344,106 @@ impl JsonRpcConnection {
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label.clone();
let incoming_tx_for_reader = incoming_tx.clone();
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
let websocket_task = tokio::spawn(async move {
let mut ping_interval = ping_interval.map(|ping_interval| {
let mut interval = tokio::time::interval_at(
tokio::time::Instant::now() + ping_interval,
ping_interval,
);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
interval
});
loop {
match websocket_reader.next().await {
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
Ok(JsonRpcWebSocketFrame::Message(message)) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
Err(err) => {
send_malformed_message(
&incoming_tx_for_reader,
Some(format!(
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
}
_ = async {
match ping_interval.as_mut() {
Some(interval) => interval.tick().await,
None => std::future::pending().await,
}
Ok(JsonRpcWebSocketFrame::Close) => {
} => {
if let Err(err) = websocket.send(M::ping()).await {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
)),
)
.await;
break;
}
Ok(JsonRpcWebSocketFrame::Ignore) => {}
},
Some(Err(err)) => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
Some(format!(
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx_for_reader,
&disconnected_tx_for_reader,
/*reason*/ None,
)
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
if let Some(keepalive_interval) = keepalive_interval {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + keepalive_interval,
keepalive_interval,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_message = outgoing_rx.recv() => {
let Some(message) = maybe_message else {
break;
};
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
_ = keepalive.tick() => {
if let Err(err) = websocket_writer.send(M::ping()).await {
incoming_message = websocket.next() => {
match incoming_message {
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
Ok(JsonRpcWebSocketFrame::Message(message)) => {
if incoming_tx
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Ok(JsonRpcWebSocketFrame::Close) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
/*reason*/ None,
)
.await;
break;
}
Ok(JsonRpcWebSocketFrame::Ignore) => {}
Err(err) => {
send_malformed_message(
&incoming_tx,
Some(format!(
"failed to parse websocket JSON-RPC message from {connection_label}: {err}"
)),
)
.await;
}
},
Some(Err(err)) => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
Some(format!(
"failed to write websocket ping to {connection_label}: {err}"
"failed to read websocket JSON-RPC message from {connection_label}: {err}"
)),
)
.await;
break;
}
None => {
send_disconnected(
&incoming_tx,
&disconnected_tx,
/*reason*/ None,
)
.await;
break;
}
}
}
}
} else {
while let Some(message) = outgoing_rx.recv().await {
if let Err(reason) = send_websocket_jsonrpc_message(
&mut websocket_writer,
&connection_label,
&message,
)
.await
{
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
break;
}
}
}
});
@@ -480,7 +451,7 @@ impl JsonRpcConnection {
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
task_handles: vec![websocket_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -619,34 +590,250 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_j
#[cfg(test)]
mod tests {
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use futures::channel::mpsc as futures_mpsc;
use futures::stream;
use futures::task::Context;
use futures::task::Poll;
use futures::task::AtomicWaker;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::connect_async;
use super::*;
struct TestWebSocketSink {
message_tx: futures_mpsc::UnboundedSender<Message>,
#[tokio::test]
async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = JsonRpcConnection::from_websocket_stream(
client_websocket,
"test".into(),
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
);
let message = timeout(Duration::from_secs(1), server_websocket.next())
.await?
.expect("websocket should stay open")?;
assert!(matches!(message, Message::Ping(_)));
drop(connection);
Ok(())
}
impl Sink<Message> for TestWebSocketSink {
#[tokio::test]
async fn websocket_connection_ignores_server_pong() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
server_websocket
.send(Message::Pong(b"check".to_vec().into()))
.await?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
.await
.is_err()
);
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_reports_server_close() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
server_websocket.close(None).await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_accepts_binary_jsonrpc_message() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
let message = JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "test".to_string(),
params: None,
trace: None,
});
server_websocket
.send(Message::Binary(serde_json::to_vec(&message)?.into()))
.await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured()
-> anyhow::Result<()> {
let (websocket, control, mut outbound_rx) =
ControlledWebSocket::new(/*write_ready*/ false);
let mut connection = JsonRpcConnection::from_websocket_stream(
websocket,
"test".into(),
/*ping_interval*/ None,
);
let message = test_jsonrpc_message();
connection.outgoing_tx.send(message.clone()).await?;
control.wait_for_blocked_write().await?;
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
.await
.is_err()
);
control.set_write_ready();
assert!(matches!(
timeout(Duration::from_secs(1), outbound_rx.next()).await?,
Some(Message::Text(text)) if serde_json::from_str::<JSONRPCMessage>(&text)? == message
));
drop(connection);
Ok(())
}
async fn websocket_pair() -> anyhow::Result<(
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
WebSocketStream<tokio::net::TcpStream>,
)> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let websocket_url = format!("ws://{}", listener.local_addr()?);
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
accept_async(stream).await.map_err(anyhow::Error::from)
});
let (client_websocket, _) = connect_async(websocket_url).await?;
let server_websocket = server_task.await??;
Ok((client_websocket, server_websocket))
}
fn test_jsonrpc_message() -> JSONRPCMessage {
JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "test".to_string(),
params: None,
trace: None,
})
}
struct ControlledWebSocket {
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
outbound_tx: futures_mpsc::UnboundedSender<Message>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
struct ControlledWebSocketHandle {
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
impl ControlledWebSocket {
fn new(
write_ready: bool,
) -> (
Self,
ControlledWebSocketHandle,
futures_mpsc::UnboundedReceiver<Message>,
) {
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
let write_ready = Arc::new(AtomicBool::new(write_ready));
let write_blocked = Arc::new(AtomicBool::new(false));
let write_blocked_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
(
Self {
inbound_rx,
outbound_tx,
write_ready: Arc::clone(&write_ready),
write_blocked: Arc::clone(&write_blocked),
write_blocked_waker: Arc::clone(&write_blocked_waker),
write_waker: Arc::clone(&write_waker),
},
ControlledWebSocketHandle {
inbound_tx,
write_ready,
write_blocked,
write_blocked_waker,
write_waker,
},
outbound_rx,
)
}
}
impl ControlledWebSocketHandle {
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
self.inbound_tx
.unbounded_send(Ok(message))
.map_err(anyhow::Error::from)
}
fn set_write_ready(&self) {
self.write_ready.store(true, Ordering::Release);
self.write_waker.wake();
}
/// Waits up to one second for the sink side to report a blocked write.
///
/// Returns an error if the timeout elapses before `write_blocked` is observed as set.
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
    timeout(
        Duration::from_secs(1),
        futures::future::poll_fn(|cx| {
            // Register before checking the flag: a wake that lands between the check and
            // the registration would otherwise be lost and this future would only finish
            // via the timeout.
            self.write_blocked_waker.register(cx.waker());
            if self.write_blocked.load(Ordering::Acquire) {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }),
    )
    .await?;
    Ok(())
}
}
impl Sink<Message> for ControlledWebSocket {
type Error = std::convert::Infallible;
fn poll_ready(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
/// Reports write readiness according to the shared `write_ready` flag.
///
/// While writes are not ready, records that a write was blocked, wakes any task waiting on
/// that fact, and parks the caller on `write_waker`.
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
    if self.write_ready.load(Ordering::Acquire) {
        return Poll::Ready(Ok(()));
    }
    // Register first, then re-check the flag, so a concurrent store-and-wake on
    // `write_ready` cannot slip in between the check and the registration.
    self.write_waker.register(cx.waker());
    if self.write_ready.load(Ordering::Acquire) {
        return Poll::Ready(Ok(()));
    }
    self.write_blocked.store(true, Ordering::Release);
    self.write_blocked_waker.wake();
    Poll::Pending
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.get_mut()
.message_tx
self.outbound_tx
.unbounded_send(item)
.expect("test websocket receiver should stay open");
.expect("test outbound receiver should stay open");
Ok(())
}
@@ -665,24 +852,11 @@ mod tests {
}
}
#[tokio::test]
async fn websocket_connection_sends_keepalive_ping() {
let (message_tx, mut message_rx) = futures_mpsc::unbounded::<Message>();
let websocket_writer = TestWebSocketSink { message_tx };
let websocket_reader = stream::pending::<Result<Message, std::convert::Infallible>>();
let connection = JsonRpcConnection::from_websocket_parts(
websocket_writer,
websocket_reader,
"test".into(),
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
);
impl Stream for ControlledWebSocket {
type Item = Result<Message, std::convert::Infallible>;
let message = timeout(Duration::from_secs(1), message_rx.next())
.await
.expect("keepalive ping should arrive before timeout")
.expect("keepalive ping should be sent");
assert!(matches!(message, Message::Ping(_)));
drop(connection);
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inbound_rx).poll_next(cx)
}
}
}
+376 -192
View File
@@ -1,7 +1,9 @@
use std::collections::HashMap;
use codex_app_server_protocol::JSONRPCMessage;
use futures::Sink;
use futures::SinkExt;
use futures::Stream;
use futures::StreamExt;
use prost::Message as ProstMessage;
use tokio::io::AsyncRead;
@@ -19,7 +21,6 @@ use crate::connection::CHANNEL_CAPACITY;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcTransport;
use crate::connection::WEBSOCKET_KEEPALIVE_INTERVAL;
use crate::relay_proto::RelayData;
use crate::relay_proto::RelayMessageFrame;
use crate::relay_proto::RelayResume;
@@ -140,121 +141,25 @@ fn jsonrpc_payload(message: &JSONRPCMessage) -> Result<Vec<u8>, ExecServerError>
serde_json::to_vec(message).map_err(ExecServerError::Json)
}
pub(crate) fn harness_connection_from_websocket<S>(
stream: WebSocketStream<S>,
pub(crate) fn harness_connection_from_websocket<T, E>(
stream: T,
connection_label: String,
) -> JsonRpcConnection
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
T: Sink<Message, Error = E> + Stream<Item = Result<Message, E>> + Unpin + Send + 'static,
E: std::fmt::Display + Send + 'static,
{
let stream_id = Uuid::new_v4().to_string();
let (mut websocket_writer, mut websocket_reader) = stream.split();
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (disconnected_tx, disconnected_rx) = watch::channel(false);
let reader_label = connection_label;
let reader_stream_id = stream_id.clone();
let incoming_tx_for_reader = incoming_tx;
let disconnected_tx_for_reader = disconnected_tx.clone();
let reader_task = tokio::spawn(async move {
loop {
match websocket_reader.next().await {
Some(Ok(Message::Binary(payload))) => {
let frame = match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: format!(
"failed to parse relay message frame from {reader_label}: {err}"
),
})
.await;
continue;
}
};
if frame.stream_id != reader_stream_id {
continue;
}
let kind = match frame.validate() {
Ok(kind) => kind,
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
continue;
}
};
match kind {
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
Ok(message) => {
if incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
}
},
RelayFrameBodyKind::Reset => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected {
reason: frame.into_reset_reason(),
})
.await;
break;
}
RelayFrameBodyKind::Ack
| RelayFrameBodyKind::Resume
| RelayFrameBodyKind::Heartbeat => {}
}
}
Some(Ok(Message::Close(_))) | None => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
.await;
break;
}
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
Some(Ok(Message::Text(_))) => {
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: "relay exec-server transport expects binary protobuf frames"
.to_string(),
})
.await;
}
Some(Err(err)) => {
let _ = disconnected_tx_for_reader.send(true);
let _ = incoming_tx_for_reader
.send(JsonRpcConnectionEvent::Disconnected {
reason: Some(format!(
"failed to read relay websocket frame from {reader_label}: {err}"
)),
})
.await;
break;
}
}
}
});
let writer_task = tokio::spawn(async move {
let websocket_task = tokio::spawn(async move {
let mut websocket = stream;
let reader_label = connection_label;
let reader_stream_id = stream_id.clone();
let resume = RelayMessageFrame::resume(stream_id.clone());
if websocket_writer
if websocket
.send(Message::Binary(encode_relay_message_frame(&resume).into()))
.await
.is_err()
@@ -263,11 +168,6 @@ where
return;
}
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut next_seq = 0u32;
loop {
tokio::select! {
@@ -284,7 +184,7 @@ where
};
let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload);
next_seq = next_seq.wrapping_add(1);
if websocket_writer
if websocket
.send(Message::Binary(encode_relay_message_frame(&frame).into()))
.await
.is_err()
@@ -293,10 +193,96 @@ where
break;
}
}
_ = keepalive.tick() => {
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
let _ = disconnected_tx.send(true);
break;
incoming_message = websocket.next() => {
match incoming_message {
Some(Ok(Message::Binary(payload))) => {
let frame = match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: format!(
"failed to parse relay message frame from {reader_label}: {err}"
),
})
.await;
continue;
}
};
if frame.stream_id != reader_stream_id {
continue;
}
let kind = match frame.validate() {
Ok(kind) => kind,
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
continue;
}
};
match kind {
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
Ok(message) => {
if incoming_tx
.send(JsonRpcConnectionEvent::Message(message))
.await
.is_err()
{
break;
}
}
Err(err) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: err.to_string(),
})
.await;
}
},
RelayFrameBodyKind::Reset => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected {
reason: frame.into_reset_reason(),
})
.await;
break;
}
RelayFrameBodyKind::Ack
| RelayFrameBodyKind::Resume
| RelayFrameBodyKind::Heartbeat => {}
}
}
Some(Ok(Message::Close(_))) | None => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
.await;
break;
}
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
Some(Ok(Message::Text(_))) => {
let _ = incoming_tx
.send(JsonRpcConnectionEvent::MalformedMessage {
reason: "relay exec-server transport expects binary protobuf frames"
.to_string(),
})
.await;
}
Some(Err(err)) => {
let _ = disconnected_tx.send(true);
let _ = incoming_tx
.send(JsonRpcConnectionEvent::Disconnected {
reason: Some(format!(
"failed to read relay websocket frame from {reader_label}: {err}"
)),
})
.await;
break;
}
}
}
}
@@ -307,7 +293,7 @@ where
outgoing_tx,
incoming_rx,
disconnected_rx,
task_handles: vec![reader_task, writer_task],
task_handles: vec![websocket_task],
transport: JsonRpcTransport::Plain,
}
}
@@ -318,59 +304,42 @@ pub(crate) async fn run_multiplexed_executor<S>(
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (mut websocket_writer, mut websocket_reader) = stream.split();
let mut websocket = stream;
let (physical_outgoing_tx, mut physical_outgoing_rx) =
mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
let writer_task = tokio::spawn(async move {
let mut keepalive = tokio::time::interval_at(
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
WEBSOCKET_KEEPALIVE_INTERVAL,
);
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
maybe_encoded = physical_outgoing_rx.recv() => {
let Some(encoded) = maybe_encoded else {
break;
};
if websocket_writer
.send(Message::Binary(encoded.into()))
.await
.is_err()
{
break;
}
}
_ = keepalive.tick() => {
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
break;
}
}
}
}
});
let mut streams: HashMap<String, VirtualStream> = HashMap::new();
loop {
let frame = match websocket_reader.next().await {
Some(Ok(Message::Binary(payload))) => {
match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
warn!("dropping malformed relay message frame from harness: {err}");
continue;
}
let frame = tokio::select! {
maybe_encoded = physical_outgoing_rx.recv() => {
    let Some(encoded) = maybe_encoded else {
        break;
    };
    if websocket.send(Message::Binary(encoded.into())).await.is_err() {
        break;
    }
    // This arm only writes an outbound frame and produces no inbound relay frame,
    // so it must diverge rather than fall through as the value of `frame`.
    continue;
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Text(_))) => {
warn!("dropping non-binary relay message frame from harness");
continue;
}
Some(Err(err)) => {
debug!("multiplexed executor websocket read failed: {err}");
break;
incoming_message = websocket.next() => match incoming_message {
Some(Ok(Message::Binary(payload))) => {
match decode_relay_message_frame(payload.as_ref()) {
Ok(frame) => frame,
Err(err) => {
warn!("dropping malformed relay message frame from harness: {err}");
continue;
}
}
}
Some(Ok(Message::Close(_))) | None => break,
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
Some(Ok(Message::Text(_))) => {
warn!("dropping non-binary relay message frame from harness");
continue;
}
Some(Err(err)) => {
debug!("multiplexed executor websocket read failed: {err}");
break;
}
}
};
@@ -423,7 +392,6 @@ pub(crate) async fn run_multiplexed_executor<S>(
stream.disconnect(/*reason*/ None).await;
}
drop(physical_outgoing_tx);
let _ = writer_task.await;
}
struct VirtualStream {
@@ -492,8 +460,20 @@ fn spawn_virtual_stream(
#[cfg(test)]
mod tests {
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCRequest;
use codex_app_server_protocol::RequestId;
use futures::Sink;
use futures::Stream;
use futures::channel::mpsc as futures_mpsc;
use futures::task::AtomicWaker;
use tokio::net::TcpListener;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
@@ -502,40 +482,107 @@ mod tests {
use super::*;
fn test_runtime_paths() -> anyhow::Result<crate::ExecServerRuntimePaths> {
crate::ExecServerRuntimePaths::new(
std::env::current_exe()?,
/*codex_linux_sandbox_exe*/ None,
)
.map_err(anyhow::Error::from)
}
#[tokio::test]
async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> {
async fn harness_connection_receives_relay_data() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let executor_task = tokio::spawn(run_multiplexed_executor(
client_websocket,
ConnectionProcessor::new(test_runtime_paths()?),
let mut connection =
harness_connection_from_websocket(client_websocket, "test".to_string());
let stream_id = read_resume_stream_id(&mut server_websocket).await?;
let message = test_jsonrpc_message();
server_websocket
.send(Message::Binary(
encode_relay_message_frame(&RelayMessageFrame::data(
stream_id,
/*seq*/ 0,
jsonrpc_payload(&message)?,
))
.into(),
))
.await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
));
read_keepalive_ping(&mut server_websocket).await?;
executor_task.abort();
let _ = executor_task.await;
drop(connection);
Ok(())
}
#[tokio::test]
async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> {
async fn harness_connection_reports_text_frames_as_malformed() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let connection = harness_connection_from_websocket(client_websocket, "test".to_string());
let mut connection =
harness_connection_from_websocket(client_websocket, "test".to_string());
read_keepalive_ping(&mut server_websocket).await?;
read_resume_stream_id(&mut server_websocket).await?;
server_websocket.send(Message::Text("nope".into())).await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::MalformedMessage { reason })
if reason == "relay exec-server transport expects binary protobuf frames"
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn harness_connection_reports_server_close() -> anyhow::Result<()> {
let (client_websocket, mut server_websocket) = websocket_pair().await?;
let mut connection =
harness_connection_from_websocket(client_websocket, "test".to_string());
read_resume_stream_id(&mut server_websocket).await?;
server_websocket.close(None).await?;
assert!(matches!(
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
));
drop(connection);
Ok(())
}
#[tokio::test]
async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured()
-> anyhow::Result<()> {
let (websocket, control, mut outbound_rx) =
ControlledWebSocket::new(/*write_ready*/ true);
let mut connection = harness_connection_from_websocket(websocket, "test".to_string());
let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
.await?
.expect("resume frame")
else {
anyhow::bail!("expected relay resume frame");
};
let stream_id = decode_relay_message_frame(resume_payload.as_ref())?.stream_id;
let message = test_jsonrpc_message();
control.set_write_blocked();
connection.outgoing_tx.send(message.clone()).await?;
control.wait_for_blocked_write().await?;
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
assert!(
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
.await
.is_err()
);
control.set_write_ready();
let Message::Binary(data_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
.await?
.expect("data frame")
else {
anyhow::bail!("expected relay data frame");
};
let frame = decode_relay_message_frame(data_payload.as_ref())?;
assert_eq!(frame.stream_id, stream_id);
assert_eq!(frame.into_jsonrpc_message()?, message);
drop(connection);
Ok(())
}
async fn websocket_pair() -> anyhow::Result<(
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
WebSocketStream<tokio::net::TcpStream>,
@@ -551,18 +598,155 @@ mod tests {
Ok((client_websocket, server_websocket))
}
async fn read_keepalive_ping(
async fn read_resume_stream_id(
websocket: &mut WebSocketStream<tokio::net::TcpStream>,
) -> anyhow::Result<()> {
loop {
let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
anyhow::bail!("websocket closed before keepalive ping");
};
match message? {
Message::Ping(_) => return Ok(()),
Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {}
Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"),
) -> anyhow::Result<String> {
let message = timeout(Duration::from_secs(1), websocket.next())
.await?
.expect("websocket should stay open")?;
let Message::Binary(payload) = message else {
anyhow::bail!("expected relay resume frame, got {message:?}");
};
let frame = decode_relay_message_frame(payload.as_ref())?;
assert_eq!(frame.validate()?, RelayFrameBodyKind::Resume);
Ok(frame.stream_id)
}
fn test_jsonrpc_message() -> JSONRPCMessage {
JSONRPCMessage::Request(JSONRPCRequest {
id: RequestId::Integer(1),
method: "test".to_string(),
params: None,
trace: None,
})
}
struct ControlledWebSocket {
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
outbound_tx: futures_mpsc::UnboundedSender<Message>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
struct ControlledWebSocketHandle {
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
write_ready: Arc<AtomicBool>,
write_blocked: Arc<AtomicBool>,
write_blocked_waker: Arc<AtomicWaker>,
write_waker: Arc<AtomicWaker>,
}
impl ControlledWebSocket {
fn new(
write_ready: bool,
) -> (
Self,
ControlledWebSocketHandle,
futures_mpsc::UnboundedReceiver<Message>,
) {
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
let write_ready = Arc::new(AtomicBool::new(write_ready));
let write_blocked = Arc::new(AtomicBool::new(false));
let write_blocked_waker = Arc::new(AtomicWaker::new());
let write_waker = Arc::new(AtomicWaker::new());
(
Self {
inbound_rx,
outbound_tx,
write_ready: Arc::clone(&write_ready),
write_blocked: Arc::clone(&write_blocked),
write_blocked_waker: Arc::clone(&write_blocked_waker),
write_waker: Arc::clone(&write_waker),
},
ControlledWebSocketHandle {
inbound_tx,
write_ready,
write_blocked,
write_blocked_waker,
write_waker,
},
outbound_rx,
)
}
}
impl ControlledWebSocketHandle {
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
self.inbound_tx
.unbounded_send(Ok(message))
.map_err(anyhow::Error::from)
}
fn set_write_blocked(&self) {
self.write_ready.store(false, Ordering::Release);
}
fn set_write_ready(&self) {
self.write_ready.store(true, Ordering::Release);
self.write_waker.wake();
}
/// Waits up to one second for the sink side to report a blocked write.
///
/// Returns an error if the timeout elapses before `write_blocked` is observed as set.
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
    timeout(
        Duration::from_secs(1),
        futures::future::poll_fn(|cx| {
            // Register before checking the flag: a wake that lands between the check and
            // the registration would otherwise be lost and this future would only finish
            // via the timeout.
            self.write_blocked_waker.register(cx.waker());
            if self.write_blocked.load(Ordering::Acquire) {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }),
    )
    .await?;
    Ok(())
}
}
impl Sink<Message> for ControlledWebSocket {
type Error = std::convert::Infallible;
/// Reports write readiness according to the shared `write_ready` flag.
///
/// While writes are not ready, records that a write was blocked, wakes any task waiting on
/// that fact, and parks the caller on `write_waker`.
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
    if self.write_ready.load(Ordering::Acquire) {
        return Poll::Ready(Ok(()));
    }
    // Register first, then re-check the flag, so a concurrent store-and-wake on
    // `write_ready` cannot slip in between the check and the registration.
    self.write_waker.register(cx.waker());
    if self.write_ready.load(Ordering::Acquire) {
        return Poll::Ready(Ok(()));
    }
    self.write_blocked.store(true, Ordering::Release);
    self.write_blocked_waker.wake();
    Poll::Pending
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.outbound_tx
.unbounded_send(item)
.expect("test outbound receiver should stay open");
Ok(())
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl Stream for ControlledWebSocket {
type Item = Result<Message, std::convert::Infallible>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inbound_rx).poll_next(cx)
}
}
}