diff --git a/codex-rs/app-server/src/connection_cleanup.rs b/codex-rs/app-server/src/connection_cleanup.rs new file mode 100644 index 000000000..529020fe3 --- /dev/null +++ b/codex-rs/app-server/src/connection_cleanup.rs @@ -0,0 +1,49 @@ +use std::future::Future; +use std::future::pending; + +use tokio::task::JoinError; +use tokio::task::JoinSet; +use tracing::warn; + /* Holds the spawned connection-cleanup tasks in one JoinSet so they can be reaped, drained, or aborted as a group. */ +pub(crate) struct ConnectionCleanupTasks { + tasks: JoinSet<()>, +} + +impl ConnectionCleanupTasks { + pub(crate) fn new() -> Self { + Self { + tasks: JoinSet::new(), + } + } + /* Spawns the future onto the owned JoinSet. */ + pub(crate) fn spawn(&mut self, future: impl Future + Send + 'static) { + self.tasks.spawn(future); + } + /* Joins one finished task and logs its result. */ + pub(crate) async fn reap_next(&mut self) { + /* Empty set: await pending, which never resolves, so this future stays idle instead of completing. */ if self.tasks.is_empty() { + pending::<()>().await; + } + if let Some(result) = self.tasks.join_next().await { + log_cleanup_result(result); + } + } + /* Joins every remaining task, logging each result. */ + pub(crate) async fn drain(&mut self) { + while let Some(result) = self.tasks.join_next().await { + log_cleanup_result(result); + } + } + /* Calls abort_all() on the set and returns without awaiting the tasks. */ + pub(crate) fn abort(&mut self) { + self.tasks.abort_all(); + } +} + /* Ignores Ok results and cancelled joins; any other JoinError is logged with warn!. */ +fn log_cleanup_result(result: Result<(), JoinError>) { + if let Err(err) = result + && !err.is_cancelled() + { + warn!("connection cleanup task failed: {err}"); + } +} diff --git a/codex-rs/app-server/src/connection_rpc_gate.rs b/codex-rs/app-server/src/connection_rpc_gate.rs index 12fed79b3..fb2aedd35 100644 --- a/codex-rs/app-server/src/connection_rpc_gate.rs +++ b/codex-rs/app-server/src/connection_rpc_gate.rs @@ -38,12 +38,14 @@ impl ConnectionRpcGate { drop(token); } + /* Clears the accepting flag and calls tasks.close() without waiting on the tasks; shutdown() is close() followed by tasks.wait(). */ pub(crate) async fn close(&self) { + let mut accepting = self.accepting.lock().await; + *accepting = false; + self.tasks.close(); + } + pub(crate) async fn shutdown(&self) { - { - let mut accepting = self.accepting.lock().await; - *accepting = false; - self.tasks.close(); - } + self.close().await; self.tasks.wait().await; } @@ -90,9 +92,9 @@ mod tests { } #[tokio::test] - async fn run_drops_future_without_polling_after_shutdown() { + async fn 
run_drops_future_without_polling_after_close() { let gate = ConnectionRpcGate::new(); - gate.shutdown().await; + gate.close().await; let polled = Arc::new(AtomicBool::new(/*v*/ false)); let polled_clone = Arc::clone(&polled); @@ -105,6 +107,33 @@ mod tests { assert!(!gate.is_accepting().await); } + #[tokio::test] + async fn close_returns_while_started_run_remains_active() { + let gate = Arc::new(ConnectionRpcGate::new()); + let (started_tx, started_rx) = oneshot::channel(); + let (finish_tx, finish_rx) = oneshot::channel(); + let gate_for_run = Arc::clone(&gate); + let run_task = tokio::spawn(async move { + gate_for_run + .run(async move { + started_tx.send(()).expect("receiver should be open"); + let _ = finish_rx.await; + }) + .await; + }); + + started_rx.await.expect("run should start"); + gate.close().await; + assert!(!gate.is_accepting().await); + assert_eq!(gate.inflight_count(), 1); + + finish_tx + .send(()) + .expect("running future should be waiting"); + run_task.await.expect("run task should complete"); + gate.shutdown().await; + } + #[tokio::test] async fn shutdown_waits_for_started_run_to_finish() { let gate = Arc::new(ConnectionRpcGate::new()); diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 0689a79f0..b6f0b8c85 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -20,6 +20,7 @@ use std::sync::atomic::AtomicBool; use crate::analytics_utils::analytics_events_client_from_config; use crate::config_manager::ConfigManager; +use crate::connection_cleanup::ConnectionCleanupTasks; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; use crate::outgoing_message::ConnectionId; @@ -81,6 +82,7 @@ mod command_exec; mod config; mod config_manager; mod config_manager_service; +mod connection_cleanup; mod connection_rpc_gate; mod dynamic_tools; mod error_code; @@ -819,6 +821,7 @@ pub async fn run_main_with_transport_options( let mut thread_created_rx = 
processor.thread_created_receiver(); let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count(); let mut connections = HashMap::::new(); + let mut connection_cleanup_tasks = ConnectionCleanupTasks::new(); let mut remote_control_status_rx = remote_control_handle.status_receiver(); let mut remote_control_status = remote_control_status_rx.borrow().clone(); let transport_shutdown_token = transport_shutdown_token.clone(); @@ -906,14 +909,20 @@ pub async fn run_main_with_transport_options( let Some(connection_state) = connections.remove(&connection_id) else { continue; }; - if outbound_control_tx + /* Close the RPC gate first; connection_closed then runs in a spawned cleanup task instead of being awaited in this loop. */ connection_state.session.rpc_gate.close().await; + let outbound_closed = outbound_control_tx .send(OutboundControlEvent::Closed { connection_id }) .await - .is_err() - { + .is_ok(); + let processor = Arc::clone(&processor); + connection_cleanup_tasks.spawn(async move { + processor + .connection_closed(connection_id, &connection_state.session) + .await; + }); + if !outbound_closed { break; } - processor.connection_closed(connection_id, &connection_state.session).await; if shutdown_when_no_connections && connections.is_empty() { break; } @@ -1010,6 +1019,7 @@ pub async fn run_main_with_transport_options( } } } + _ = connection_cleanup_tasks.reap_next() => {} changed = remote_control_status_rx.changed() => { if changed.is_err() { continue; @@ -1062,8 +1072,11 @@ pub async fn run_main_with_transport_options( .map(|connection_state| connection_state.session.rpc_gate.shutdown()), ) .await; + connection_cleanup_tasks.drain().await; processor.drain_background_tasks().await; processor.shutdown_threads().await; + } else { + connection_cleanup_tasks.abort(); } info!("processor task exited (channel closed)"); } diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index b7e0ff258..65c57bf26 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -91,6 +91,7 @@ 
use tokio_util::sync::CancellationToken; use tracing::Instrument; const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10); +const CONNECTION_RPC_DRAIN_TIMEOUT: Duration = Duration::from_secs(/*secs*/ 30); #[derive(Clone)] struct ExternalAuthRefreshBridge { @@ -723,7 +724,19 @@ impl MessageProcessor { connection_id: ConnectionId, session_state: &ConnectionSessionState, ) { - session_state.rpc_gate.shutdown().await; + /* Wait at most CONNECTION_RPC_DRAIN_TIMEOUT for the gate to shut down; on timeout only a warning is logged and the cleanup below still runs. */ if timeout( + CONNECTION_RPC_DRAIN_TIMEOUT, + session_state.rpc_gate.shutdown(), + ) + .await + .is_err() + { + tracing::warn!( + ?connection_id, + timeout_seconds = CONNECTION_RPC_DRAIN_TIMEOUT.as_secs(), + "timed out waiting for connection RPCs to drain" + ); + } self.outgoing.connection_closed(connection_id).await; self.fs_processor.connection_closed(connection_id).await; self.command_exec_processor diff --git a/codex-rs/app-server/src/request_serialization.rs b/codex-rs/app-server/src/request_serialization.rs index 0dd167b74..77ecfc8f5 100644 --- a/codex-rs/app-server/src/request_serialization.rs +++ b/codex-rs/app-server/src/request_serialization.rs @@ -311,7 +311,7 @@ mod tests { let key = RequestSerializationQueueKey::Global("test"); let live_gate = gate(); let closed_gate = gate(); - closed_gate.shutdown().await; + closed_gate.close().await; let (tx, mut rx) = mpsc::unbounded_channel(); let (blocked_tx, blocked_rx) = oneshot::channel::<()>();