mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
Refactor exec-server websocket pump (#23327)
## Why Exec-server websocket handling had separate reader and writer tasks for the same socket. That made websocket control-frame handling asymmetric: the task reading frames could observe `Ping`, but the task allowed to write frames was elsewhere. This PR moves each physical websocket onto one always-running pump so the socket owner can handle application frames and websocket control frames together. ## What changed - Refactored direct exec-server websocket connections in `connection.rs` to use one task that owns the websocket for outbound JSON-RPC, inbound JSON-RPC, periodic keepalive pings, and `Ping` -> `Pong` replies. - Refactored relay websocket handling in `relay.rs` the same way for both the harness-side logical connection and the multiplexed executor physical socket. - Preserved the existing keepalive ownership policy: outbound direct websocket clients still send periodic pings, inbound Axum accepts only reply with pongs, and relay physical websocket endpoints keep their existing periodic pings. - Added focused websocket pump tests for ping/pong, binary JSON-RPC, relay data, malformed relay text frames, and close/disconnect behavior. - Reconnect behavior is intentionally left for a follow-up. ## Validation - Devbox Bazel focused unit target: - `//codex-rs/exec-server:exec-server-unit-tests --test_filter='websocket_connection_|harness_connection_|multiplexed_executor_'`
This commit is contained in:
committed by
GitHub
Unverified
parent
5c43a64e2b
commit
83af3abc68
@@ -323,37 +323,20 @@ impl JsonRpcConnection {
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
connection_label,
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
)
|
||||
Self::from_websocket_stream(stream, connection_label, /*ping_interval*/ None)
|
||||
}
|
||||
|
||||
pub(crate) fn from_axum_websocket(stream: AxumWebSocket, connection_label: String) -> Self {
|
||||
let (websocket_writer, websocket_reader) = stream.split();
|
||||
Self::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
connection_label,
|
||||
// Axum only wraps inbound exec-server websocket accepts. Outbound websocket clients
|
||||
// own keepalive pings so one side does not accidentally create redundant traffic.
|
||||
/*keepalive_interval*/
|
||||
None,
|
||||
)
|
||||
Self::from_websocket_stream(stream, connection_label, Some(WEBSOCKET_KEEPALIVE_INTERVAL))
|
||||
}
|
||||
|
||||
fn from_websocket_parts<W, R, M, E>(
|
||||
mut websocket_writer: W,
|
||||
mut websocket_reader: R,
|
||||
fn from_websocket_stream<T, M, E>(
|
||||
mut websocket: T,
|
||||
connection_label: String,
|
||||
keepalive_interval: Option<Duration>,
|
||||
ping_interval: Option<Duration>,
|
||||
) -> Self
|
||||
where
|
||||
W: Sink<M, Error = E> + Unpin + Send + 'static,
|
||||
R: Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
T: Sink<M, Error = E> + Stream<Item = Result<M, E>> + Unpin + Send + 'static,
|
||||
M: JsonRpcWebSocketMessage,
|
||||
E: std::fmt::Display + Send + 'static,
|
||||
{
|
||||
@@ -361,118 +344,106 @@ impl JsonRpcConnection {
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label.clone();
|
||||
let incoming_tx_for_reader = incoming_tx.clone();
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let websocket_task = tokio::spawn(async move {
|
||||
let mut ping_interval = ping_interval.map(|ping_interval| {
|
||||
let mut interval = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + ping_interval,
|
||||
ping_interval,
|
||||
);
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
interval
|
||||
});
|
||||
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
tokio::select! {
|
||||
maybe_message = outgoing_rx.recv() => {
|
||||
let Some(message) = maybe_message else {
|
||||
break;
|
||||
};
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
_ = async {
|
||||
match ping_interval.as_mut() {
|
||||
Some(interval) => interval.tick().await,
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
} => {
|
||||
if let Err(err) = websocket.send(M::ping()).await {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket ping to {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
Some(format!(
|
||||
"failed to read websocket JSON-RPC message from {reader_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx_for_reader,
|
||||
&disconnected_tx_for_reader,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
if let Some(keepalive_interval) = keepalive_interval {
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + keepalive_interval,
|
||||
keepalive_interval,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_message = outgoing_rx.recv() => {
|
||||
let Some(message) = maybe_message else {
|
||||
break;
|
||||
};
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket_writer,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if let Err(err) = websocket_writer.send(M::ping()).await {
|
||||
incoming_message = websocket.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(message)) => match message.parse_jsonrpc_frame() {
|
||||
Ok(JsonRpcWebSocketFrame::Message(message)) => {
|
||||
if incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Close) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Ok(JsonRpcWebSocketFrame::Ignore) => {}
|
||||
Err(err) => {
|
||||
send_malformed_message(
|
||||
&incoming_tx,
|
||||
Some(format!(
|
||||
"failed to parse websocket JSON-RPC message from {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
Some(format!(
|
||||
"failed to write websocket ping to {connection_label}: {err}"
|
||||
"failed to read websocket JSON-RPC message from {connection_label}: {err}"
|
||||
)),
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
send_disconnected(
|
||||
&incoming_tx,
|
||||
&disconnected_tx,
|
||||
/*reason*/ None,
|
||||
)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while let Some(message) = outgoing_rx.recv().await {
|
||||
if let Err(reason) = send_websocket_jsonrpc_message(
|
||||
&mut websocket_writer,
|
||||
&connection_label,
|
||||
&message,
|
||||
)
|
||||
.await
|
||||
{
|
||||
send_disconnected(&incoming_tx, &disconnected_tx, Some(reason)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -480,7 +451,7 @@ impl JsonRpcConnection {
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
task_handles: vec![websocket_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
@@ -619,34 +590,250 @@ fn serialize_jsonrpc_message(message: &JSONRPCMessage) -> Result<String, serde_j
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use futures::channel::mpsc as futures_mpsc;
|
||||
use futures::stream;
|
||||
use futures::task::Context;
|
||||
use futures::task::Poll;
|
||||
use futures::task::AtomicWaker;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::accept_async;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
use super::*;
|
||||
|
||||
struct TestWebSocketSink {
|
||||
message_tx: futures_mpsc::UnboundedSender<Message>,
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_sends_configured_ping() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let connection = JsonRpcConnection::from_websocket_stream(
|
||||
client_websocket,
|
||||
"test".into(),
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
);
|
||||
|
||||
let message = timeout(Duration::from_secs(1), server_websocket.next())
|
||||
.await?
|
||||
.expect("websocket should stay open")?;
|
||||
assert!(matches!(message, Message::Ping(_)));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Sink<Message> for TestWebSocketSink {
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_ignores_server_pong() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
|
||||
server_websocket
|
||||
.send(Message::Pong(b"check".to_vec().into()))
|
||||
.await?;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_reports_server_close() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
|
||||
server_websocket.close(None).await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_accepts_binary_jsonrpc_message() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection = JsonRpcConnection::from_websocket(client_websocket, "test".into());
|
||||
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "test".to_string(),
|
||||
params: None,
|
||||
trace: None,
|
||||
});
|
||||
|
||||
server_websocket
|
||||
.send(Message::Binary(serde_json::to_vec(&message)?.into()))
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
|
||||
));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_keeps_outbound_message_while_send_is_backpressured()
|
||||
-> anyhow::Result<()> {
|
||||
let (websocket, control, mut outbound_rx) =
|
||||
ControlledWebSocket::new(/*write_ready*/ false);
|
||||
let mut connection = JsonRpcConnection::from_websocket_stream(
|
||||
websocket,
|
||||
"test".into(),
|
||||
/*ping_interval*/ None,
|
||||
);
|
||||
let message = test_jsonrpc_message();
|
||||
|
||||
connection.outgoing_tx.send(message.clone()).await?;
|
||||
control.wait_for_blocked_write().await?;
|
||||
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
control.set_write_ready();
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), outbound_rx.next()).await?,
|
||||
Some(Message::Text(text)) if serde_json::from_str::<JSONRPCMessage>(&text)? == message
|
||||
));
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn websocket_pair() -> anyhow::Result<(
|
||||
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
|
||||
WebSocketStream<tokio::net::TcpStream>,
|
||||
)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await?;
|
||||
let websocket_url = format!("ws://{}", listener.local_addr()?);
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
accept_async(stream).await.map_err(anyhow::Error::from)
|
||||
});
|
||||
let (client_websocket, _) = connect_async(websocket_url).await?;
|
||||
let server_websocket = server_task.await??;
|
||||
Ok((client_websocket, server_websocket))
|
||||
}
|
||||
|
||||
fn test_jsonrpc_message() -> JSONRPCMessage {
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "test".to_string(),
|
||||
params: None,
|
||||
trace: None,
|
||||
})
|
||||
}
|
||||
|
||||
struct ControlledWebSocket {
|
||||
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
|
||||
outbound_tx: futures_mpsc::UnboundedSender<Message>,
|
||||
write_ready: Arc<AtomicBool>,
|
||||
write_blocked: Arc<AtomicBool>,
|
||||
write_blocked_waker: Arc<AtomicWaker>,
|
||||
write_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
struct ControlledWebSocketHandle {
|
||||
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
|
||||
write_ready: Arc<AtomicBool>,
|
||||
write_blocked: Arc<AtomicBool>,
|
||||
write_blocked_waker: Arc<AtomicWaker>,
|
||||
write_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
impl ControlledWebSocket {
|
||||
fn new(
|
||||
write_ready: bool,
|
||||
) -> (
|
||||
Self,
|
||||
ControlledWebSocketHandle,
|
||||
futures_mpsc::UnboundedReceiver<Message>,
|
||||
) {
|
||||
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
|
||||
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
|
||||
let write_ready = Arc::new(AtomicBool::new(write_ready));
|
||||
let write_blocked = Arc::new(AtomicBool::new(false));
|
||||
let write_blocked_waker = Arc::new(AtomicWaker::new());
|
||||
let write_waker = Arc::new(AtomicWaker::new());
|
||||
(
|
||||
Self {
|
||||
inbound_rx,
|
||||
outbound_tx,
|
||||
write_ready: Arc::clone(&write_ready),
|
||||
write_blocked: Arc::clone(&write_blocked),
|
||||
write_blocked_waker: Arc::clone(&write_blocked_waker),
|
||||
write_waker: Arc::clone(&write_waker),
|
||||
},
|
||||
ControlledWebSocketHandle {
|
||||
inbound_tx,
|
||||
write_ready,
|
||||
write_blocked,
|
||||
write_blocked_waker,
|
||||
write_waker,
|
||||
},
|
||||
outbound_rx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ControlledWebSocketHandle {
|
||||
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
|
||||
self.inbound_tx
|
||||
.unbounded_send(Ok(message))
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
fn set_write_ready(&self) {
|
||||
self.write_ready.store(true, Ordering::Release);
|
||||
self.write_waker.wake();
|
||||
}
|
||||
|
||||
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
|
||||
timeout(
|
||||
Duration::from_secs(1),
|
||||
futures::future::poll_fn(|cx| {
|
||||
if self.write_blocked.load(Ordering::Acquire) {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.write_blocked_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for ControlledWebSocket {
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if self.write_ready.load(Ordering::Acquire) {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
self.write_blocked.store(true, Ordering::Release);
|
||||
self.write_blocked_waker.wake();
|
||||
self.write_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
self.get_mut()
|
||||
.message_tx
|
||||
self.outbound_tx
|
||||
.unbounded_send(item)
|
||||
.expect("test websocket receiver should stay open");
|
||||
.expect("test outbound receiver should stay open");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -665,24 +852,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn websocket_connection_sends_keepalive_ping() {
|
||||
let (message_tx, mut message_rx) = futures_mpsc::unbounded::<Message>();
|
||||
let websocket_writer = TestWebSocketSink { message_tx };
|
||||
let websocket_reader = stream::pending::<Result<Message, std::convert::Infallible>>();
|
||||
let connection = JsonRpcConnection::from_websocket_parts(
|
||||
websocket_writer,
|
||||
websocket_reader,
|
||||
"test".into(),
|
||||
Some(WEBSOCKET_KEEPALIVE_INTERVAL),
|
||||
);
|
||||
impl Stream for ControlledWebSocket {
|
||||
type Item = Result<Message, std::convert::Infallible>;
|
||||
|
||||
let message = timeout(Duration::from_secs(1), message_rx.next())
|
||||
.await
|
||||
.expect("keepalive ping should arrive before timeout")
|
||||
.expect("keepalive ping should be sent");
|
||||
assert!(matches!(message, Message::Ping(_)));
|
||||
|
||||
drop(connection);
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.inbound_rx).poll_next(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+376
-192
@@ -1,7 +1,9 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use futures::Sink;
|
||||
use futures::SinkExt;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use prost::Message as ProstMessage;
|
||||
use tokio::io::AsyncRead;
|
||||
@@ -19,7 +21,6 @@ use crate::connection::CHANNEL_CAPACITY;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::connection::JsonRpcConnectionEvent;
|
||||
use crate::connection::JsonRpcTransport;
|
||||
use crate::connection::WEBSOCKET_KEEPALIVE_INTERVAL;
|
||||
use crate::relay_proto::RelayData;
|
||||
use crate::relay_proto::RelayMessageFrame;
|
||||
use crate::relay_proto::RelayResume;
|
||||
@@ -140,121 +141,25 @@ fn jsonrpc_payload(message: &JSONRPCMessage) -> Result<Vec<u8>, ExecServerError>
|
||||
serde_json::to_vec(message).map_err(ExecServerError::Json)
|
||||
}
|
||||
|
||||
pub(crate) fn harness_connection_from_websocket<S>(
|
||||
stream: WebSocketStream<S>,
|
||||
pub(crate) fn harness_connection_from_websocket<T, E>(
|
||||
stream: T,
|
||||
connection_label: String,
|
||||
) -> JsonRpcConnection
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
T: Sink<Message, Error = E> + Stream<Item = Result<Message, E>> + Unpin + Send + 'static,
|
||||
E: std::fmt::Display + Send + 'static,
|
||||
{
|
||||
let stream_id = Uuid::new_v4().to_string();
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let (disconnected_tx, disconnected_rx) = watch::channel(false);
|
||||
|
||||
let reader_label = connection_label;
|
||||
let reader_stream_id = stream_id.clone();
|
||||
let incoming_tx_for_reader = incoming_tx;
|
||||
let disconnected_tx_for_reader = disconnected_tx.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match websocket_reader.next().await {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
let frame = match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: format!(
|
||||
"failed to parse relay message frame from {reader_label}: {err}"
|
||||
),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if frame.stream_id != reader_stream_id {
|
||||
continue;
|
||||
}
|
||||
let kind = match frame.validate() {
|
||||
Ok(kind) => kind,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match kind {
|
||||
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
|
||||
Ok(message) => {
|
||||
if incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
},
|
||||
RelayFrameBodyKind::Reset => {
|
||||
let _ = disconnected_tx_for_reader.send(true);
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: frame.into_reset_reason(),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
RelayFrameBodyKind::Ack
|
||||
| RelayFrameBodyKind::Resume
|
||||
| RelayFrameBodyKind::Heartbeat => {}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
let _ = disconnected_tx_for_reader.send(true);
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: "relay exec-server transport expects binary protobuf frames"
|
||||
.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
let _ = disconnected_tx_for_reader.send(true);
|
||||
let _ = incoming_tx_for_reader
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: Some(format!(
|
||||
"failed to read relay websocket frame from {reader_label}: {err}"
|
||||
)),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let websocket_task = tokio::spawn(async move {
|
||||
let mut websocket = stream;
|
||||
let reader_label = connection_label;
|
||||
let reader_stream_id = stream_id.clone();
|
||||
let resume = RelayMessageFrame::resume(stream_id.clone());
|
||||
if websocket_writer
|
||||
if websocket
|
||||
.send(Message::Binary(encode_relay_message_frame(&resume).into()))
|
||||
.await
|
||||
.is_err()
|
||||
@@ -263,11 +168,6 @@ where
|
||||
return;
|
||||
}
|
||||
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
let mut next_seq = 0u32;
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -284,7 +184,7 @@ where
|
||||
};
|
||||
let frame = RelayMessageFrame::data(stream_id.clone(), next_seq, payload);
|
||||
next_seq = next_seq.wrapping_add(1);
|
||||
if websocket_writer
|
||||
if websocket
|
||||
.send(Message::Binary(encode_relay_message_frame(&frame).into()))
|
||||
.await
|
||||
.is_err()
|
||||
@@ -293,10 +193,96 @@ where
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
|
||||
let _ = disconnected_tx.send(true);
|
||||
break;
|
||||
incoming_message = websocket.next() => {
|
||||
match incoming_message {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
let frame = match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: format!(
|
||||
"failed to parse relay message frame from {reader_label}: {err}"
|
||||
),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if frame.stream_id != reader_stream_id {
|
||||
continue;
|
||||
}
|
||||
let kind = match frame.validate() {
|
||||
Ok(kind) => kind,
|
||||
Err(err) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match kind {
|
||||
RelayFrameBodyKind::Data => match frame.into_jsonrpc_message() {
|
||||
Ok(message) => {
|
||||
if incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Message(message))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: err.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
},
|
||||
RelayFrameBodyKind::Reset => {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: frame.into_reset_reason(),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
RelayFrameBodyKind::Ack
|
||||
| RelayFrameBodyKind::Resume
|
||||
| RelayFrameBodyKind::Heartbeat => {}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {}
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::MalformedMessage {
|
||||
reason: "relay exec-server transport expects binary protobuf frames"
|
||||
.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
let _ = disconnected_tx.send(true);
|
||||
let _ = incoming_tx
|
||||
.send(JsonRpcConnectionEvent::Disconnected {
|
||||
reason: Some(format!(
|
||||
"failed to read relay websocket frame from {reader_label}: {err}"
|
||||
)),
|
||||
})
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -307,7 +293,7 @@ where
|
||||
outgoing_tx,
|
||||
incoming_rx,
|
||||
disconnected_rx,
|
||||
task_handles: vec![reader_task, writer_task],
|
||||
task_handles: vec![websocket_task],
|
||||
transport: JsonRpcTransport::Plain,
|
||||
}
|
||||
}
|
||||
@@ -318,59 +304,42 @@ pub(crate) async fn run_multiplexed_executor<S>(
|
||||
) where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let (mut websocket_writer, mut websocket_reader) = stream.split();
|
||||
let mut websocket = stream;
|
||||
let (physical_outgoing_tx, mut physical_outgoing_rx) =
|
||||
mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut keepalive = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
WEBSOCKET_KEEPALIVE_INTERVAL,
|
||||
);
|
||||
keepalive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_encoded = physical_outgoing_rx.recv() => {
|
||||
let Some(encoded) = maybe_encoded else {
|
||||
break;
|
||||
};
|
||||
if websocket_writer
|
||||
.send(Message::Binary(encoded.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = keepalive.tick() => {
|
||||
if websocket_writer.send(Message::Ping(Vec::new().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut streams: HashMap<String, VirtualStream> = HashMap::new();
|
||||
loop {
|
||||
let frame = match websocket_reader.next().await {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
warn!("dropping malformed relay message frame from harness: {err}");
|
||||
continue;
|
||||
}
|
||||
let frame = tokio::select! {
|
||||
maybe_encoded = physical_outgoing_rx.recv() => {
|
||||
let Some(encoded) = maybe_encoded else {
|
||||
break;
|
||||
};
|
||||
if websocket.send(Message::Binary(encoded.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
warn!("dropping non-binary relay message frame from harness");
|
||||
continue;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
debug!("multiplexed executor websocket read failed: {err}");
|
||||
break;
|
||||
incoming_message = websocket.next() => match incoming_message {
|
||||
Some(Ok(Message::Binary(payload))) => {
|
||||
match decode_relay_message_frame(payload.as_ref()) {
|
||||
Ok(frame) => frame,
|
||||
Err(err) => {
|
||||
warn!("dropping malformed relay message frame from harness: {err}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
warn!("dropping non-binary relay message frame from harness");
|
||||
continue;
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
debug!("multiplexed executor websocket read failed: {err}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -423,7 +392,6 @@ pub(crate) async fn run_multiplexed_executor<S>(
|
||||
stream.disconnect(/*reason*/ None).await;
|
||||
}
|
||||
drop(physical_outgoing_tx);
|
||||
let _ = writer_task.await;
|
||||
}
|
||||
|
||||
struct VirtualStream {
|
||||
@@ -492,8 +460,20 @@ fn spawn_virtual_stream(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCRequest;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use futures::Sink;
|
||||
use futures::Stream;
|
||||
use futures::channel::mpsc as futures_mpsc;
|
||||
use futures::task::AtomicWaker;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::timeout;
|
||||
use tokio_tungstenite::accept_async;
|
||||
@@ -502,40 +482,107 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_runtime_paths() -> anyhow::Result<crate::ExecServerRuntimePaths> {
|
||||
crate::ExecServerRuntimePaths::new(
|
||||
std::env::current_exe()?,
|
||||
/*codex_linux_sandbox_exe*/ None,
|
||||
)
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multiplexed_executor_sends_keepalive_ping() -> anyhow::Result<()> {
|
||||
async fn harness_connection_receives_relay_data() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let executor_task = tokio::spawn(run_multiplexed_executor(
|
||||
client_websocket,
|
||||
ConnectionProcessor::new(test_runtime_paths()?),
|
||||
let mut connection =
|
||||
harness_connection_from_websocket(client_websocket, "test".to_string());
|
||||
let stream_id = read_resume_stream_id(&mut server_websocket).await?;
|
||||
let message = test_jsonrpc_message();
|
||||
|
||||
server_websocket
|
||||
.send(Message::Binary(
|
||||
encode_relay_message_frame(&RelayMessageFrame::data(
|
||||
stream_id,
|
||||
/*seq*/ 0,
|
||||
jsonrpc_payload(&message)?,
|
||||
))
|
||||
.into(),
|
||||
))
|
||||
.await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::Message(actual)) if actual == message
|
||||
));
|
||||
|
||||
read_keepalive_ping(&mut server_websocket).await?;
|
||||
|
||||
executor_task.abort();
|
||||
let _ = executor_task.await;
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn harness_connection_sends_keepalive_ping() -> anyhow::Result<()> {
|
||||
async fn harness_connection_reports_text_frames_as_malformed() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let connection = harness_connection_from_websocket(client_websocket, "test".to_string());
|
||||
let mut connection =
|
||||
harness_connection_from_websocket(client_websocket, "test".to_string());
|
||||
|
||||
read_keepalive_ping(&mut server_websocket).await?;
|
||||
read_resume_stream_id(&mut server_websocket).await?;
|
||||
server_websocket.send(Message::Text("nope".into())).await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::MalformedMessage { reason })
|
||||
if reason == "relay exec-server transport expects binary protobuf frames"
|
||||
));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn harness_connection_reports_server_close() -> anyhow::Result<()> {
|
||||
let (client_websocket, mut server_websocket) = websocket_pair().await?;
|
||||
let mut connection =
|
||||
harness_connection_from_websocket(client_websocket, "test".to_string());
|
||||
|
||||
read_resume_stream_id(&mut server_websocket).await?;
|
||||
server_websocket.close(None).await?;
|
||||
assert!(matches!(
|
||||
timeout(Duration::from_secs(1), connection.incoming_rx.recv()).await?,
|
||||
Some(JsonRpcConnectionEvent::Disconnected { reason: None })
|
||||
));
|
||||
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn harness_connection_keeps_outbound_frame_while_send_is_backpressured()
|
||||
-> anyhow::Result<()> {
|
||||
let (websocket, control, mut outbound_rx) =
|
||||
ControlledWebSocket::new(/*write_ready*/ true);
|
||||
let mut connection = harness_connection_from_websocket(websocket, "test".to_string());
|
||||
let Message::Binary(resume_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
|
||||
.await?
|
||||
.expect("resume frame")
|
||||
else {
|
||||
anyhow::bail!("expected relay resume frame");
|
||||
};
|
||||
let stream_id = decode_relay_message_frame(resume_payload.as_ref())?.stream_id;
|
||||
let message = test_jsonrpc_message();
|
||||
|
||||
control.set_write_blocked();
|
||||
connection.outgoing_tx.send(message.clone()).await?;
|
||||
control.wait_for_blocked_write().await?;
|
||||
control.send_inbound(Message::Pong(b"check".to_vec().into()))?;
|
||||
assert!(
|
||||
timeout(Duration::from_millis(50), connection.incoming_rx.recv())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
control.set_write_ready();
|
||||
let Message::Binary(data_payload) = timeout(Duration::from_secs(1), outbound_rx.next())
|
||||
.await?
|
||||
.expect("data frame")
|
||||
else {
|
||||
anyhow::bail!("expected relay data frame");
|
||||
};
|
||||
let frame = decode_relay_message_frame(data_payload.as_ref())?;
|
||||
assert_eq!(frame.stream_id, stream_id);
|
||||
assert_eq!(frame.into_jsonrpc_message()?, message);
|
||||
drop(connection);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn websocket_pair() -> anyhow::Result<(
|
||||
WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
|
||||
WebSocketStream<tokio::net::TcpStream>,
|
||||
@@ -551,18 +598,155 @@ mod tests {
|
||||
Ok((client_websocket, server_websocket))
|
||||
}
|
||||
|
||||
async fn read_keepalive_ping(
|
||||
async fn read_resume_stream_id(
|
||||
websocket: &mut WebSocketStream<tokio::net::TcpStream>,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let Some(message) = timeout(Duration::from_secs(1), websocket.next()).await? else {
|
||||
anyhow::bail!("websocket closed before keepalive ping");
|
||||
};
|
||||
match message? {
|
||||
Message::Ping(_) => return Ok(()),
|
||||
Message::Binary(_) | Message::Text(_) | Message::Pong(_) | Message::Frame(_) => {}
|
||||
Message::Close(_) => anyhow::bail!("websocket closed before keepalive ping"),
|
||||
) -> anyhow::Result<String> {
|
||||
let message = timeout(Duration::from_secs(1), websocket.next())
|
||||
.await?
|
||||
.expect("websocket should stay open")?;
|
||||
let Message::Binary(payload) = message else {
|
||||
anyhow::bail!("expected relay resume frame, got {message:?}");
|
||||
};
|
||||
let frame = decode_relay_message_frame(payload.as_ref())?;
|
||||
assert_eq!(frame.validate()?, RelayFrameBodyKind::Resume);
|
||||
Ok(frame.stream_id)
|
||||
}
|
||||
|
||||
fn test_jsonrpc_message() -> JSONRPCMessage {
|
||||
JSONRPCMessage::Request(JSONRPCRequest {
|
||||
id: RequestId::Integer(1),
|
||||
method: "test".to_string(),
|
||||
params: None,
|
||||
trace: None,
|
||||
})
|
||||
}
|
||||
|
||||
struct ControlledWebSocket {
|
||||
inbound_rx: futures_mpsc::UnboundedReceiver<Result<Message, std::convert::Infallible>>,
|
||||
outbound_tx: futures_mpsc::UnboundedSender<Message>,
|
||||
write_ready: Arc<AtomicBool>,
|
||||
write_blocked: Arc<AtomicBool>,
|
||||
write_blocked_waker: Arc<AtomicWaker>,
|
||||
write_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
struct ControlledWebSocketHandle {
|
||||
inbound_tx: futures_mpsc::UnboundedSender<Result<Message, std::convert::Infallible>>,
|
||||
write_ready: Arc<AtomicBool>,
|
||||
write_blocked: Arc<AtomicBool>,
|
||||
write_blocked_waker: Arc<AtomicWaker>,
|
||||
write_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
impl ControlledWebSocket {
|
||||
fn new(
|
||||
write_ready: bool,
|
||||
) -> (
|
||||
Self,
|
||||
ControlledWebSocketHandle,
|
||||
futures_mpsc::UnboundedReceiver<Message>,
|
||||
) {
|
||||
let (inbound_tx, inbound_rx) = futures_mpsc::unbounded();
|
||||
let (outbound_tx, outbound_rx) = futures_mpsc::unbounded();
|
||||
let write_ready = Arc::new(AtomicBool::new(write_ready));
|
||||
let write_blocked = Arc::new(AtomicBool::new(false));
|
||||
let write_blocked_waker = Arc::new(AtomicWaker::new());
|
||||
let write_waker = Arc::new(AtomicWaker::new());
|
||||
(
|
||||
Self {
|
||||
inbound_rx,
|
||||
outbound_tx,
|
||||
write_ready: Arc::clone(&write_ready),
|
||||
write_blocked: Arc::clone(&write_blocked),
|
||||
write_blocked_waker: Arc::clone(&write_blocked_waker),
|
||||
write_waker: Arc::clone(&write_waker),
|
||||
},
|
||||
ControlledWebSocketHandle {
|
||||
inbound_tx,
|
||||
write_ready,
|
||||
write_blocked,
|
||||
write_blocked_waker,
|
||||
write_waker,
|
||||
},
|
||||
outbound_rx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ControlledWebSocketHandle {
|
||||
fn send_inbound(&self, message: Message) -> anyhow::Result<()> {
|
||||
self.inbound_tx
|
||||
.unbounded_send(Ok(message))
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
fn set_write_blocked(&self) {
|
||||
self.write_ready.store(false, Ordering::Release);
|
||||
}
|
||||
|
||||
fn set_write_ready(&self) {
|
||||
self.write_ready.store(true, Ordering::Release);
|
||||
self.write_waker.wake();
|
||||
}
|
||||
|
||||
async fn wait_for_blocked_write(&self) -> anyhow::Result<()> {
|
||||
timeout(
|
||||
Duration::from_secs(1),
|
||||
futures::future::poll_fn(|cx| {
|
||||
if self.write_blocked.load(Ordering::Acquire) {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.write_blocked_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for ControlledWebSocket {
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if self.write_ready.load(Ordering::Acquire) {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
self.write_blocked.store(true, Ordering::Release);
|
||||
self.write_blocked_waker.wake();
|
||||
self.write_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
self.outbound_tx
|
||||
.unbounded_send(item)
|
||||
.expect("test outbound receiver should stay open");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for ControlledWebSocket {
|
||||
type Item = Result<Message, std::convert::Infallible>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.inbound_rx).poll_next(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user