mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
refactor: use cloneable async channels for shared receivers (#18398)
This is the first mechanical cleanup in a stack whose higher-level goal is to enable Clippy coverage for async guards held across `.await` points. The follow-up commits enable Clippy's [`await_holding_lock`](https://rust-lang.github.io/rust-clippy/master/index.html#await_holding_lock) lint and the configurable [`await_holding_invalid_type`](https://rust-lang.github.io/rust-clippy/master/index.html#await_holding_invalid_type) lint for Tokio guard types. This PR handles the cases where the underlying issue is not protected shared mutable state, but a `tokio::sync::mpsc::UnboundedReceiver` wrapped in `Arc<Mutex<_>>` so cloned owners can call `recv().await`. Using a mutex for that shape forces the receiver lock guard to live across `.await`. Switching these paths to `async-channel` gives us cloneable `Receiver`s, so each owner can hold a receiver handle directly and await messages without an async mutex guard. ## What changed - In `codex-rs/code-mode`, replace the turn-message `mpsc::UnboundedSender`/`UnboundedReceiver` plus `Arc<Mutex<Receiver>>` with `async_channel::Sender`/`Receiver`. - In `codex-rs/codex-api`, replace the realtime websocket event receiver with an `async_channel::Receiver`, allowing `RealtimeWebsocketEvents` clones to receive without locking. - Add `async-channel` as a dependency for `codex-code-mode` and `codex-api`, and update `Cargo.lock`. ## Verification - The split stack was verified at the final lint-enabling head with `just clippy`.
This commit is contained in:
committed by
GitHub
Unverified
parent
0e111e08d0
commit
c9c4caafd8
Generated
+2
@@ -1409,6 +1409,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assert_matches",
|
||||
"async-channel",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
@@ -1842,6 +1843,7 @@ dependencies = [
|
||||
name = "codex-code-mode"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"async-trait",
|
||||
"codex-protocol",
|
||||
"deno_core_icudata",
|
||||
|
||||
@@ -13,6 +13,7 @@ path = "src/lib.rs"
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-channel = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
deno_core_icudata = { workspace = true }
|
||||
|
||||
@@ -44,8 +44,8 @@ struct SessionHandle {
|
||||
struct Inner {
|
||||
stored_values: Mutex<HashMap<String, JsonValue>>,
|
||||
sessions: Mutex<HashMap<String, SessionHandle>>,
|
||||
turn_message_tx: mpsc::UnboundedSender<TurnMessage>,
|
||||
turn_message_rx: Arc<Mutex<mpsc::UnboundedReceiver<TurnMessage>>>,
|
||||
turn_message_tx: async_channel::Sender<TurnMessage>,
|
||||
turn_message_rx: async_channel::Receiver<TurnMessage>,
|
||||
next_cell_id: AtomicU64,
|
||||
}
|
||||
|
||||
@@ -55,14 +55,14 @@ pub struct CodeModeService {
|
||||
|
||||
impl CodeModeService {
|
||||
pub fn new() -> Self {
|
||||
let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel();
|
||||
let (turn_message_tx, turn_message_rx) = async_channel::unbounded();
|
||||
|
||||
Self {
|
||||
inner: Arc::new(Inner {
|
||||
stored_values: Mutex::new(HashMap::new()),
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
turn_message_tx,
|
||||
turn_message_rx: Arc::new(Mutex::new(turn_message_rx)),
|
||||
turn_message_rx,
|
||||
next_cell_id: AtomicU64::new(1),
|
||||
}),
|
||||
}
|
||||
@@ -146,16 +146,13 @@ impl CodeModeService {
|
||||
pub fn start_turn_worker(&self, host: Arc<dyn CodeModeTurnHost>) -> CodeModeTurnWorker {
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
|
||||
let inner = Arc::clone(&self.inner);
|
||||
let turn_message_rx = Arc::clone(&self.inner.turn_message_rx);
|
||||
let turn_message_rx = self.inner.turn_message_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let next_message = tokio::select! {
|
||||
_ = &mut shutdown_rx => break,
|
||||
message = async {
|
||||
let mut turn_message_rx = turn_message_rx.lock().await;
|
||||
turn_message_rx.recv().await
|
||||
} => message,
|
||||
message = turn_message_rx.recv() => message.ok(),
|
||||
};
|
||||
let Some(next_message) = next_message else {
|
||||
break;
|
||||
@@ -361,7 +358,7 @@ async fn run_session_control(
|
||||
cell_id: cell_id.clone(),
|
||||
call_id,
|
||||
text,
|
||||
});
|
||||
}).await;
|
||||
}
|
||||
RuntimeEvent::ToolCall { id, name, input } => {
|
||||
let _ = inner.turn_message_tx.send(TurnMessage::ToolCall {
|
||||
@@ -369,7 +366,7 @@ async fn run_session_control(
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
});
|
||||
}).await;
|
||||
}
|
||||
RuntimeEvent::Result {
|
||||
stored_values,
|
||||
@@ -500,12 +497,12 @@ mod tests {
|
||||
}
|
||||
|
||||
fn test_inner() -> Arc<Inner> {
|
||||
let (turn_message_tx, turn_message_rx) = mpsc::unbounded_channel();
|
||||
let (turn_message_tx, turn_message_rx) = async_channel::unbounded();
|
||||
Arc::new(Inner {
|
||||
stored_values: Mutex::new(HashMap::new()),
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
turn_message_tx,
|
||||
turn_message_rx: Arc::new(Mutex::new(turn_message_rx)),
|
||||
turn_message_rx,
|
||||
next_cell_id: AtomicU64::new(1),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-channel = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
|
||||
@@ -65,9 +65,9 @@ enum WsCommand {
|
||||
impl WsStream {
|
||||
fn new(
|
||||
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
) -> (Self, mpsc::UnboundedReceiver<Result<Message, WsError>>) {
|
||||
) -> (Self, async_channel::Receiver<Result<Message, WsError>>) {
|
||||
let (tx_command, mut rx_command) = mpsc::channel::<WsCommand>(32);
|
||||
let (tx_message, rx_message) = mpsc::unbounded_channel::<Result<Message, WsError>>();
|
||||
let (tx_message, rx_message) = async_channel::unbounded::<Result<Message, WsError>>();
|
||||
|
||||
let pump_task = tokio::spawn(async move {
|
||||
let mut inner = inner;
|
||||
@@ -110,7 +110,7 @@ impl WsStream {
|
||||
trace!(payload_len = payload.len(), "realtime websocket received ping");
|
||||
if let Err(err) = inner.send(Message::Pong(payload)).await {
|
||||
error!("realtime websocket failed to send pong: {err}");
|
||||
let _ = tx_message.send(Err(err));
|
||||
let _ = tx_message.send(Err(err)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ impl WsStream {
|
||||
}
|
||||
Message::Ping(_) | Message::Pong(_) => {}
|
||||
}
|
||||
if tx_message.send(Ok(message)).is_err() {
|
||||
if tx_message.send(Ok(message)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
if is_close {
|
||||
@@ -147,7 +147,7 @@ impl WsStream {
|
||||
}
|
||||
Err(err) => {
|
||||
error!("realtime websocket receive failed: {err}");
|
||||
let _ = tx_message.send(Err(err));
|
||||
let _ = tx_message.send(Err(err)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -208,7 +208,7 @@ pub struct RealtimeWebsocketWriter {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeWebsocketEvents {
|
||||
rx_message: Arc<Mutex<mpsc::UnboundedReceiver<Result<Message, WsError>>>>,
|
||||
rx_message: async_channel::Receiver<Result<Message, WsError>>,
|
||||
active_transcript: Arc<Mutex<ActiveTranscriptState>>,
|
||||
event_parser: RealtimeEventParser,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
@@ -256,7 +256,7 @@ impl RealtimeWebsocketConnection {
|
||||
|
||||
fn new(
|
||||
stream: WsStream,
|
||||
rx_message: mpsc::UnboundedReceiver<Result<Message, WsError>>,
|
||||
rx_message: async_channel::Receiver<Result<Message, WsError>>,
|
||||
event_parser: RealtimeEventParser,
|
||||
) -> Self {
|
||||
let stream = Arc::new(stream);
|
||||
@@ -268,7 +268,7 @@ impl RealtimeWebsocketConnection {
|
||||
event_parser,
|
||||
},
|
||||
events: RealtimeWebsocketEvents {
|
||||
rx_message: Arc::new(Mutex::new(rx_message)),
|
||||
rx_message,
|
||||
active_transcript: Arc::new(Mutex::new(ActiveTranscriptState::default())),
|
||||
event_parser,
|
||||
is_closed,
|
||||
@@ -369,16 +369,16 @@ impl RealtimeWebsocketEvents {
|
||||
}
|
||||
|
||||
loop {
|
||||
let msg = match self.rx_message.lock().await.recv().await {
|
||||
Some(Ok(msg)) => msg,
|
||||
Some(Err(err)) => {
|
||||
let msg = match self.rx_message.recv().await {
|
||||
Ok(Ok(msg)) => msg,
|
||||
Ok(Err(err)) => {
|
||||
self.is_closed.store(true, Ordering::SeqCst);
|
||||
error!("realtime websocket read failed: {err}");
|
||||
return Err(ApiError::Stream(format!(
|
||||
"failed to read websocket message: {err}"
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
Err(_) => {
|
||||
self.is_closed.store(true, Ordering::SeqCst);
|
||||
info!("realtime websocket event stream ended");
|
||||
return Ok(None);
|
||||
|
||||
Reference in New Issue
Block a user