Resume exec-server sessions after disconnect (#28512)

Supersedes #28288 (closed).

## Why

A short WebSocket interruption currently ends every client-side process
handle, even though exec-server keeps the server session and its
processes alive for a short time.

This is especially visible for executor-backed stdio MCP servers: a
temporary connection loss becomes a permanent `Transport closed` error.
The server already has the information needed to resume the session, but
the client opens a fresh session instead of using it.

This change reconnects below the process and MCP layers. Existing
process handles stay valid, missed output is recovered, and the same
server-side processes continue running.

## State machine

One logical `ExecServerClient` stays alive while its underlying RPC
connection changes generations.

```text
                         transport closes
       +------------------------------------------------+
       |                                                v
+-------------+                                  +-------------+
|  Connected  |                                  | Recovering  |
+-------------+                                  +-------------+
       ^                                                |
       | session resumed, processes caught up           | retryable error
       +------------------------------------------------+ loops until deadline
                                                        |
                                                        | deadline or permanent error
                                                        v
                                                  +-------------+
                                                  |   Failed    |
                                                  +-------------+
```

### `Connected`

- New RPC calls use the current connection.
- Process notifications are published in sequence order.
- A disconnect only starts recovery if it came from the current
connection generation. Late events from older generations cannot replace
the active connection.
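
The generation rule can be sketched as follows. This is a minimal, std-only illustration rather than the PR's code: the `Connection` type, its method names, and the `u64` generation counter are hypothetical, and the real client stores an RPC handle in its `Connected` state.

```rust
/// Hypothetical connection status for illustration only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Status {
    Connected { generation: u64 },
    Recovering,
    Failed(String),
}

/// One logical connection whose underlying transport changes generations.
pub struct Connection {
    status: Status,
    next_generation: u64,
}

impl Connection {
    /// Starts connected on generation 1.
    pub fn new() -> Self {
        Self {
            status: Status::Connected { generation: 1 },
            next_generation: 2,
        }
    }

    pub fn status(&self) -> &Status {
        &self.status
    }

    /// Handles a disconnect event tagged with the generation that emitted it.
    /// Returns true only when the event came from the current generation and
    /// therefore started recovery; late events from older generations are ignored.
    pub fn on_disconnect(&mut self, from_generation: u64) -> bool {
        match self.status {
            Status::Connected { generation } if generation == from_generation => {
                self.status = Status::Recovering;
                true
            }
            _ => false,
        }
    }

    /// Installs a replacement transport after a successful resume and
    /// returns its generation. Does nothing unless recovery is in progress.
    pub fn on_resumed(&mut self) -> Option<u64> {
        if self.status != Status::Recovering {
            return None;
        }
        let generation = self.next_generation;
        self.next_generation += 1;
        self.status = Status::Connected { generation };
        Some(generation)
    }
}
```

With this shape, a disconnect event that arrives late from generation 1 cannot disturb a client that has already resumed on generation 2.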

### `Recovering`

- New calls wait instead of choosing a half-connected RPC client.
- Existing process handles, wake subscriptions, and event subscriptions
stay open.
- Streaming HTTP response bodies fail immediately because their byte
streams cannot be resumed safely.
- Recovery first waits for process starts that were already in flight. A
start whose result became ambiguous is cleaned up after reconnection
instead of being silently adopted.
- The client reconnects with the learned `session_id`. The server may
briefly report that the old connection is still attached, so that error
is retried until the detach finishes.
- The notification consumer starts before the resume handshake
completes. This prevents a busy process from filling the notification
queue and blocking the initialize response.
- Before installing the new connection, the client catches up every
recoverable process with `process/read`.
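
The retry behaviour during recovery might look roughly like the loop below. It is a synchronous, std-only sketch under stated assumptions: `recover_until`, `AttemptError`, and `RecoveryOutcome` are hypothetical names, the fixed `backoff` is an illustrative choice, and the real client is asynchronous.

```rust
use std::time::{Duration, Instant};

/// Hypothetical classification of a failed resume attempt.
#[derive(Debug, PartialEq, Eq)]
pub enum AttemptError {
    /// For example, the server still reports the old connection as attached.
    Retryable(String),
    /// Any error that another attempt cannot fix.
    Permanent(String),
}

#[derive(Debug, PartialEq, Eq)]
pub enum RecoveryOutcome<T> {
    Resumed(T),
    Failed(String),
}

/// Calls `attempt` until it succeeds, returns a permanent error, or the
/// deadline has passed after a retryable error.
pub fn recover_until<T>(
    deadline: Duration,
    backoff: Duration,
    mut attempt: impl FnMut() -> Result<T, AttemptError>,
) -> RecoveryOutcome<T> {
    let started = Instant::now();
    loop {
        match attempt() {
            Ok(value) => return RecoveryOutcome::Resumed(value),
            Err(AttemptError::Permanent(message)) => {
                return RecoveryOutcome::Failed(message);
            }
            Err(AttemptError::Retryable(message)) => {
                if started.elapsed() >= deadline {
                    return RecoveryOutcome::Failed(format!(
                        "recovery deadline exceeded: {message}"
                    ));
                }
                // Give the server time to finish detaching the old connection.
                std::thread::sleep(backoff);
            }
        }
    }
}
```

Keeping the retryable and permanent cases as separate variants mirrors the dedicated error code described under "Other details": only the "still attached" race is worth retrying.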

### `Failed`

- Recovery stops after 25 seconds or after a permanent error.
- Waiting calls are released with one stable disconnect error.
- Existing process sessions receive a terminal failure instead of
waiting forever.
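
The way callers wait during `Recovering` and are released on `Connected` or `Failed` can be illustrated with a blocking gate. This is only a sketch: the diff uses an async `watch` channel (`connection_changed`), whereas this example assumes a `Mutex` plus `Condvar`, and `ConnectionGate`, `set`, and `connection` are hypothetical names.

```rust
use std::sync::{Arc, Condvar, Mutex, PoisonError};

/// Hypothetical status; `Connected` carries a connection id for illustration.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Status {
    Connected(u64),
    Recovering,
    Failed(String),
}

pub struct ConnectionGate {
    status: Mutex<Status>,
    changed: Condvar,
}

impl ConnectionGate {
    pub fn new(status: Status) -> Arc<Self> {
        Arc::new(Self {
            status: Mutex::new(status),
            changed: Condvar::new(),
        })
    }

    /// Publishes a new status and wakes every waiting caller.
    pub fn set(&self, status: Status) {
        *self.status.lock().unwrap_or_else(PoisonError::into_inner) = status;
        self.changed.notify_all();
    }

    /// Blocks while recovery is in progress. Returns the connection id once
    /// connected, or the stored failure message once recovery has failed.
    pub fn connection(&self) -> Result<u64, String> {
        let guard = self.status.lock().unwrap_or_else(PoisonError::into_inner);
        let guard = self
            .changed
            .wait_while(guard, |status| *status == Status::Recovering)
            .unwrap_or_else(PoisonError::into_inner);
        let result = match &*guard {
            Status::Connected(id) => Ok(*id),
            Status::Failed(message) => Err(message.clone()),
            Status::Recovering => unreachable!("wait_while returns only after recovery ends"),
        };
        result
    }
}
```

Because the failure message is stored once in the shared status, every caller released from the wait, and every later caller, sees the same error text.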

## Recovering process events

Output, exit, and close events share one sequence. During normal
operation, the client buffers early events until every lower sequence
has been published.

After reconnection, the client reads each process starting after its
last published sequence:

1. Retained output chunks are inserted by sequence number.
2. Exit and close state are reconstructed in their sequence positions.
3. Events already received as live notifications are ignored as
duplicates.
4. Newly contiguous events are published in order.
5. If the server no longer retains enough output to fill a sequence gap,
only that process is terminated and failed. The recovered connection
remains usable for other processes.

The server reports its full next event sequence for unbounded reads,
including exit and close events. Closed processes remain readable for
the same 30-second window used to retain detached sessions.
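
The ordering and catch-up rules above can be sketched with a small reorder buffer. It follows the `pending: BTreeMap` idea visible in the diff, but the `Event` enum, the `catch_up` signature, and the gap check against the server's next sequence number are simplifying assumptions for illustration.

```rust
use std::collections::BTreeMap;

/// Simplified process events; the real events carry more fields.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Event {
    Output(Vec<u8>),
    Exited(i32),
    Closed,
}

#[derive(Default)]
pub struct OrderedEvents {
    last_published_seq: u64,
    /// Events that arrived before all lower sequence numbers were published.
    pending: BTreeMap<u64, Event>,
    published: Vec<(u64, Event)>,
}

impl OrderedEvents {
    /// Accepts a live or recovered event and publishes whatever is now contiguous.
    pub fn accept(&mut self, seq: u64, event: Event) {
        if seq <= self.last_published_seq {
            return; // Already published: ignore the duplicate.
        }
        // Keep the first copy if the same sequence number arrives twice.
        self.pending.entry(seq).or_insert(event);
        while let Some(event) = self.pending.remove(&(self.last_published_seq + 1)) {
            self.last_published_seq += 1;
            self.published.push((self.last_published_seq, event));
        }
    }

    /// Merges the result of a catch-up read. `next_seq` is assumed to be the
    /// server's next event sequence number. Returns an error when an event
    /// below `next_seq` is still missing, meaning the server no longer retains it.
    pub fn catch_up(&mut self, retained: Vec<(u64, Event)>, next_seq: u64) -> Result<(), String> {
        for (seq, event) in retained {
            self.accept(seq, event);
        }
        if self.last_published_seq + 1 < next_seq {
            return Err(format!(
                "event {} is no longer retained",
                self.last_published_seq + 1
            ));
        }
        Ok(())
    }

    pub fn published_seqs(&self) -> Vec<u64> {
        self.published.iter().map(|(seq, _)| *seq).collect()
    }
}
```

In this sketch a failed `catch_up` affects only the one `OrderedEvents` value, which corresponds to failing a single process while the recovered connection stays usable for the others.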

## Other details

- Detached server sessions are retained for 30 seconds, leaving margin
around the client's 25-second recovery deadline.
- Session attach and detach update the active notification sender under
the same attachment lock, so an old connection cannot clear a newly
attached sender.
- A dedicated error code distinguishes the temporary "session is still
attached" race from permanent initialization errors.
- Process starts are identity-checked on both client and server. Cleanup
from an older start cannot remove a newer process that reused the same
ID.
- Mutating requests that were already in flight when the transport
closed are not replayed, because the client cannot know whether the
server applied them. Requests started after recovery is known wait for
the replacement connection.
- We assume the client and server versions stay in sync: both run either
the pre-PR or the post-PR version, not a mix of the two.
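
The identity check on process starts can be illustrated with a registry that removes an entry only when it is the exact state object the caller registered. This is a simplified, single-threaded sketch of the `remove_session_if` idea from the diff; `Registry`, `ProcessState`, and the string process IDs are hypothetical.

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Stand-in for per-process client state.
pub struct ProcessState {
    pub label: &'static str,
}

#[derive(Default)]
pub struct Registry {
    sessions: HashMap<String, Arc<ProcessState>>,
}

impl Registry {
    /// Registers a process; returns false if the ID is already in use.
    pub fn insert(&mut self, id: &str, state: Arc<ProcessState>) -> bool {
        if self.sessions.contains_key(id) {
            return false;
        }
        self.sessions.insert(id.to_string(), state);
        true
    }

    /// Removes the entry only if it still points at `expected`, compared by
    /// pointer identity, so stale cleanup cannot evict a newer process that
    /// reused the same ID.
    pub fn remove_if(&mut self, id: &str, expected: &Arc<ProcessState>) -> bool {
        let is_same = matches!(
            self.sessions.get(id),
            Some(current) if Arc::ptr_eq(current, expected)
        );
        if is_same {
            self.sessions.remove(id);
        }
        is_same
    }

    pub fn contains(&self, id: &str) -> bool {
        self.sessions.contains_key(id)
    }
}
```

Comparing with `Arc::ptr_eq` rather than by ID is what makes a second cleanup from the older start a no-op once the ID has been reused.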

## User impact

Long-running commands and stdio MCP servers can survive a temporary
exec-server WebSocket interruption without changing process IDs or
losing output produced during the outage.

Author: jif (committed by GitHub)
Date: 2026-06-17 09:20:39 +01:00
Parent: 1315198853
Commit: cf17e1bc20

11 changed files with 1366 additions and 364 deletions
+411 -255
@@ -3,7 +3,9 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::OnceLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use arc_swap::ArcSwap;
@@ -25,6 +27,7 @@ use crate::client_api::ExecServerTransportParams;
use crate::client_api::HttpClient;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerConnectArgs;
use crate::client_transport::ExecServerReconnectStrategy;
use crate::connection::JsonRpcConnection;
use crate::process::ExecProcessEvent;
use crate::process::ExecProcessEventLog;
@@ -95,9 +98,10 @@ use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
use crate::rpc::RpcCallError;
use crate::rpc::RpcClient;
use crate::rpc::RpcClientEvent;
pub(crate) mod http_client;
#[path = "client_recovery.rs"]
mod recovery;
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
@@ -150,16 +154,19 @@ pub(crate) struct SessionState {
wake_tx: watch::Sender<u64>,
events: ExecProcessEventLog,
ordered_events: StdMutex<OrderedSessionEvents>,
failure: Mutex<Option<String>>,
recoverable: AtomicBool,
}
#[derive(Default)]
struct OrderedSessionEvents {
last_published_seq: u64,
exit_published: bool,
closed_published: bool,
// Server-side output, exit, and closed notifications are emitted by
// different tasks and can reach the client out of order. Keep future events
// here until all lower sequence numbers have been published.
pending: BTreeMap<u64, ExecProcessEvent>,
failure: Option<String>,
}
#[derive(Clone)]
@@ -170,7 +177,8 @@ pub(crate) struct Session {
}
struct Inner {
client: RpcClient,
connection: StdMutex<ConnectionState>,
connection_changed: watch::Sender<()>,
// The remote transport delivers one shared notification stream for every
// process on the connection. Keep a local process_id -> session registry so
// we can turn those connection-global notifications into process wakeups
@@ -179,11 +187,7 @@ struct Inner {
// ArcSwap makes reads cheap on the hot notification path, but writes still
// need serialization so concurrent register/remove operations do not
// overwrite each other's copy-on-write updates.
sessions_write_lock: Mutex<()>,
// Once the transport closes, every environment operation should fail quickly
// with the same canonical message. This client never reconnects, so the
// latch only moves from unset to set once.
disconnected: OnceLock<String>,
sessions_write_lock: StdMutex<()>,
// Streaming HTTP responses are keyed by a client-generated request id
// because they share the same connection-global notification channel as
// process output. Keep the routing table local to the client so higher
@@ -192,14 +196,19 @@ struct Inner {
http_body_stream_failures: ArcSwap<HashMap<String, String>>,
http_body_streams_write_lock: Mutex<()>,
http_body_stream_next_id: AtomicU64,
session_id: std::sync::RwLock<Option<String>>,
reader_task: tokio::task::JoinHandle<()>,
session_id: OnceLock<String>,
reconnect_strategy: Option<ExecServerReconnectStrategy>,
}
impl Drop for Inner {
fn drop(&mut self) {
self.reader_task.abort();
}
struct ConnectionState {
status: ConnectionStatus,
active_process_starts: usize,
}
enum ConnectionStatus {
Connected(Arc<RpcClient>),
Recovering,
Failed(String),
}
#[derive(Clone)]
@@ -207,6 +216,16 @@ pub struct ExecServerClient {
inner: Arc<Inner>,
}
struct ActiveProcessStart {
inner: Arc<Inner>,
}
impl Drop for ActiveProcessStart {
fn drop(&mut self) {
self.inner.finish_process_start();
}
}
#[derive(Clone)]
pub(crate) struct LazyRemoteExecServerClient {
transport_params: ExecServerTransportParams,
@@ -339,6 +358,15 @@ impl ExecServerClient {
pub async fn initialize(
&self,
options: ExecServerClientConnectOptions,
) -> Result<InitializeResponse, ExecServerError> {
let rpc_client = self.inner.rpc_client().await?;
self.initialize_rpc(&rpc_client, options).await
}
async fn initialize_rpc(
&self,
rpc_client: &RpcClient,
options: ExecServerClientConnectOptions,
) -> Result<InitializeResponse, ExecServerError> {
let ExecServerClientConnectOptions {
client_name,
@@ -347,9 +375,7 @@ impl ExecServerClient {
} = options;
timeout(initialize_timeout, async {
let response: InitializeResponse = self
.inner
.client
let response: InitializeResponse = rpc_client
.call(
INITIALIZE_METHOD,
&InitializeParams {
@@ -358,15 +384,19 @@ impl ExecServerClient {
},
)
.await?;
{
let mut session_id = self
.inner
.session_id
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*session_id = Some(response.session_id.clone());
let session_id = self
.inner
.session_id
.get_or_init(|| response.session_id.clone());
if session_id != &response.session_id {
return Err(ExecServerError::Protocol(format!(
"exec-server initialized an unexpected session {}",
response.session_id
)));
}
self.notify_initialized().await?;
rpc_client
.notify(INITIALIZED_METHOD, &serde_json::json!({}))
.await?;
Ok(response)
})
.await
@@ -503,14 +533,72 @@ impl ExecServerClient {
self.call(FS_COPY_METHOD, &params).await
}
pub(crate) async fn start_process(
&self,
params: ExecParams,
) -> Result<Session, ExecServerError> {
loop {
let rpc_client = self.inner.rpc_client().await?;
if !self.inner.begin_process_start(&rpc_client) {
continue;
}
let process_id = params.process_id.clone();
let state = Arc::new(SessionState::new(/*recoverable*/ false));
if let Err(error) = self.inner.insert_session(&process_id, Arc::clone(&state)) {
self.inner.finish_process_start();
return Err(error);
}
let active_start = ActiveProcessStart {
inner: Arc::clone(&self.inner),
};
let client = self.clone();
let (result_tx, result_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let _active_start = active_start;
match client
.call_rpc::<_, ExecResponse>(&rpc_client, EXEC_METHOD, &params)
.await
{
Ok(_) => {
state.recoverable.store(true, Ordering::Release);
let session = Session {
client: client.clone(),
process_id: process_id.clone(),
state: Arc::clone(&state),
};
if result_tx.send(Ok(session)).is_err() {
state.recoverable.store(false, Ordering::Release);
tokio::spawn(async move {
cleanup_process_start(&client, &process_id, &state).await;
});
}
}
Err(error) => {
if is_transport_closed_error(&error) {
tokio::spawn(async move {
cleanup_process_start(&client, &process_id, &state).await;
});
} else {
client.inner.remove_session_if(&process_id, &state);
}
let _ = result_tx.send(Err(error));
}
}
});
return result_rx.await.map_err(|_| {
ExecServerError::Protocol("process start task stopped unexpectedly".to_string())
})?;
}
}
#[cfg(test)]
pub(crate) async fn register_session(
&self,
process_id: &ProcessId,
) -> Result<Session, ExecServerError> {
let state = Arc::new(SessionState::new());
self.inner
.insert_session(process_id, Arc::clone(&state))
.await?;
let state = Arc::new(SessionState::new(/*recoverable*/ true));
self.inner.insert_session(process_id, Arc::clone(&state))?;
Ok(Session {
client: self.clone(),
process_id: process_id.clone(),
@@ -518,84 +606,52 @@ impl ExecServerClient {
})
}
pub(crate) async fn unregister_session(&self, process_id: &ProcessId) {
self.inner.remove_session(process_id).await;
}
pub fn session_id(&self) -> Option<String> {
self.inner
.session_id
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone()
self.inner.session_id.get().cloned()
}
fn is_disconnected(&self) -> bool {
self.inner.disconnected.get().is_some() || self.inner.client.is_disconnected()
self.inner.is_failed()
}
pub(crate) async fn connect(
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
) -> Result<Self, ExecServerError> {
let (rpc_client, mut events_rx) = RpcClient::new(connection);
let inner = Arc::new_cyclic(|weak| {
let weak = weak.clone();
let reader_task = tokio::spawn(async move {
while let Some(event) = events_rx.recv().await {
match event {
RpcClientEvent::Notification(notification) => {
if let Some(inner) = weak.upgrade()
&& let Err(err) =
handle_server_notification(&inner, notification).await
{
let message = record_disconnected(
&inner,
format!("exec-server notification handling failed: {err}"),
);
fail_all_in_flight_work(&inner, message).await;
return;
}
}
RpcClientEvent::Disconnected { reason } => {
if let Some(inner) = weak.upgrade() {
let message = record_disconnected(
&inner,
disconnected_message(reason.as_deref()),
);
fail_all_in_flight_work(&inner, message).await;
}
return;
}
}
}
});
Inner {
client: rpc_client,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),
disconnected: OnceLock::new(),
http_body_streams: ArcSwap::from_pointee(HashMap::new()),
http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()),
http_body_streams_write_lock: Mutex::new(()),
http_body_stream_next_id: AtomicU64::new(1),
session_id: std::sync::RwLock::new(None),
reader_task,
}
});
let client = Self { inner };
client.initialize(options).await?;
Ok(client)
Self::connect_with_recovery(connection, options, /*reconnect_strategy*/ None).await
}
async fn notify_initialized(&self) -> Result<(), ExecServerError> {
self.inner
.client
.notify(INITIALIZED_METHOD, &serde_json::json!({}))
.await
.map_err(ExecServerError::Json)
pub(crate) async fn connect_with_recovery(
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
reconnect_strategy: Option<ExecServerReconnectStrategy>,
) -> Result<Self, ExecServerError> {
let (rpc_client, events_rx) = RpcClient::new(connection);
let rpc_client = Arc::new(rpc_client);
let session_id = OnceLock::new();
let (connection_changed, _connection_changed_rx) = watch::channel(());
let inner = Arc::new(Inner {
connection: StdMutex::new(ConnectionState {
status: ConnectionStatus::Connected(Arc::clone(&rpc_client)),
active_process_starts: 0,
}),
connection_changed,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: StdMutex::new(()),
http_body_streams: ArcSwap::from_pointee(HashMap::new()),
http_body_stream_failures: ArcSwap::from_pointee(HashMap::new()),
http_body_streams_write_lock: Mutex::new(()),
http_body_stream_next_id: AtomicU64::new(1),
session_id,
reconnect_strategy,
});
let client = Self { inner };
// An explicit resume can redirect notifications from running processes
// before initialize returns. Drain them immediately so a burst cannot
// fill the bounded event channel and block the initialize response.
client.spawn_rpc_reader(&rpc_client, events_rx);
client.initialize_rpc(&rpc_client, options).await?;
Ok(client)
}
async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, ExecServerError>
@@ -603,24 +659,28 @@ impl ExecServerClient {
P: serde::Serialize,
T: serde::de::DeserializeOwned,
{
// Reject new work before allocating a JSON-RPC request id. MCP tool
// calls, process writes, and fs operations all pass through here, so
// this is the shared low-level failure path after environment disconnect.
if let Some(error) = self.inner.disconnected_error() {
return Err(error);
}
let rpc_client = self.inner.rpc_client().await?;
self.call_rpc(&rpc_client, method, params).await
}
match self.inner.client.call(method, params).await {
async fn call_rpc<P, T>(
&self,
rpc_client: &Arc<RpcClient>,
method: &str,
params: &P,
) -> Result<T, ExecServerError>
where
P: serde::Serialize,
T: serde::de::DeserializeOwned,
{
match rpc_client.call(method, params).await {
Ok(response) => Ok(response),
Err(error) => {
let error = ExecServerError::from(error);
if is_transport_closed_error(&error) {
// A call can race with disconnect after the preflight
// check. Only the reader task drains sessions so queued
// process notifications stay ordered before disconnect.
let message = disconnected_message(/*reason*/ None);
let message = record_disconnected(&self.inner, message);
Err(ExecServerError::Disconnected(message))
Err(ExecServerError::Disconnected(disconnected_message(
/*reason*/ None,
)))
} else {
Err(error)
}
@@ -629,6 +689,23 @@ impl ExecServerClient {
}
}
async fn cleanup_process_start(
client: &ExecServerClient,
process_id: &ProcessId,
state: &Arc<SessionState>,
) {
loop {
match client.terminate(process_id).await {
Ok(_) => break,
Err(error) if is_transport_closed_error(&error) && !client.inner.is_failed() => {
continue;
}
Err(_) => break,
}
}
client.inner.remove_session_if(process_id, state);
}
impl From<RpcCallError> for ExecServerError {
fn from(value: RpcCallError) -> Self {
match value {
@@ -643,7 +720,7 @@ impl From<RpcCallError> for ExecServerError {
}
impl SessionState {
fn new() -> Self {
fn new(recoverable: bool) -> Self {
let (wake_tx, _wake_rx) = watch::channel(0);
Self {
wake_tx,
@@ -652,7 +729,7 @@ impl SessionState {
PROCESS_EVENT_RETAINED_BYTES,
),
ordered_events: StdMutex::new(OrderedSessionEvents::default()),
failure: Mutex::new(None),
recoverable: AtomicBool::new(recoverable),
}
}
@@ -665,8 +742,8 @@ impl SessionState {
}
fn note_change(&self, seq: u64) {
let next = (*self.wake_tx.borrow()).max(seq);
let _ = self.wake_tx.send(next);
self.wake_tx
.send_modify(|current| *current = (*current).max(seq));
}
/// Publishes a process event only when all earlier sequenced events have
@@ -682,55 +759,61 @@ impl SessionState {
return false;
};
let mut ready = Vec::new();
let mut ordered_events = self
.ordered_events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
// We have already delivered this sequence number or moved past it,
// so accepting it again would duplicate output or lifecycle events.
if ordered_events.failure.is_some()
|| ordered_events.closed_published
|| seq <= ordered_events.last_published_seq
{
let mut ordered_events = self
.ordered_events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
// We have already delivered this sequence number or moved past it,
// so accepting it again would duplicate output or lifecycle events.
if seq <= ordered_events.last_published_seq {
return false;
}
ordered_events.pending.entry(seq).or_insert(event);
loop {
let next_seq = ordered_events.last_published_seq + 1;
let Some(event) = ordered_events.pending.remove(&next_seq) else {
break;
};
ordered_events.last_published_seq += 1;
ready.push(event);
}
return false;
}
ordered_events.pending.entry(seq).or_insert(event);
self.publish_ready(&mut ordered_events)
}
fn publish_ready(&self, ordered_events: &mut OrderedSessionEvents) -> bool {
let mut published_closed = false;
for event in ready {
published_closed |= matches!(&event, ExecProcessEvent::Closed { .. });
loop {
let next_seq = ordered_events.last_published_seq.saturating_add(1);
let Some(event) = ordered_events.pending.remove(&next_seq) else {
break;
};
ordered_events.last_published_seq = next_seq;
ordered_events.exit_published |= matches!(&event, ExecProcessEvent::Exited { .. });
let is_closed = matches!(&event, ExecProcessEvent::Closed { .. });
ordered_events.closed_published |= is_closed;
published_closed |= is_closed;
self.events.publish(event);
}
published_closed
}
async fn set_failure(&self, message: String) {
let mut failure = self.failure.lock().await;
let should_publish = failure.is_none();
if should_publish {
*failure = Some(message.clone());
}
drop(failure);
let next = (*self.wake_tx.borrow()).saturating_add(1);
let _ = self.wake_tx.send(next);
if should_publish {
let _ = self.publish_ordered_event(ExecProcessEvent::Failed(message));
fn set_failure(&self, message: String) {
let mut ordered_events = self
.ordered_events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if ordered_events.failure.is_some() || ordered_events.closed_published {
return;
}
ordered_events.failure = Some(message.clone());
ordered_events.pending.clear();
self.events.publish(ExecProcessEvent::Failed(message));
drop(ordered_events);
self.wake_tx
.send_modify(|current| *current = current.saturating_add(1));
}
async fn failed_response(&self) -> Option<ReadResponse> {
self.failure
fn failed_response(&self) -> Option<ReadResponse> {
self.ordered_events
.lock()
.await
.unwrap_or_else(std::sync::PoisonError::into_inner)
.failure
.clone()
.map(|message| self.synthesized_failure(message))
}
@@ -767,27 +850,37 @@ impl Session {
max_bytes: Option<usize>,
wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError> {
if let Some(response) = self.state.failed_response().await {
return Ok(response);
}
match self
.client
.read(ReadParams {
process_id: self.process_id.clone(),
after_seq,
max_bytes,
wait_ms,
})
.await
{
Ok(response) => Ok(response),
Err(err) if is_transport_closed_error(&err) => {
let message = disconnected_message(/*reason*/ None);
self.state.set_failure(message.clone()).await;
Ok(self.state.synthesized_failure(message))
loop {
if let Some(response) = self.state.failed_response() {
return Ok(response);
}
match self
.client
.read(ReadParams {
process_id: self.process_id.clone(),
after_seq,
max_bytes,
wait_ms,
})
.await
{
Ok(response) => return Ok(response),
Err(error)
if is_transport_closed_error(&error) && !self.client.inner.is_failed() =>
{
continue;
}
Err(error) if is_transport_closed_error(&error) => {
if let Some(response) = self.state.failed_response() {
return Ok(response);
}
let message = error.to_string();
self.state.set_failure(message.clone());
return Ok(self.state.synthesized_failure(message));
}
Err(error) => return Err(error),
}
Err(err) => Err(err),
}
}
@@ -805,40 +898,31 @@ impl Session {
}
pub(crate) async fn unregister(&self) {
self.client.unregister_session(&self.process_id).await;
self.client
.inner
.remove_session_if(&self.process_id, &self.state);
}
}
impl Inner {
fn disconnected_error(&self) -> Option<ExecServerError> {
self.disconnected
.get()
.cloned()
.map(ExecServerError::Disconnected)
}
fn set_disconnected(&self, message: String) -> Option<String> {
match self.disconnected.set(message.clone()) {
Ok(()) => Some(message),
Err(_) => None,
}
}
fn get_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
self.sessions.load().get(process_id).cloned()
}
async fn insert_session(
fn insert_session(
&self,
process_id: &ProcessId,
session: Arc<SessionState>,
) -> Result<(), ExecServerError> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
let _sessions_write_guard = self
.sessions_write_lock
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
// Do not register a process session that can never receive environment
// notifications. Without this check, remote MCP startup could create a
// dead session and wait for process output that will never arrive.
if let Some(error) = self.disconnected_error() {
return Err(error);
if let Some(message) = self.failure_message() {
return Err(ExecServerError::Disconnected(message));
}
let sessions = self.sessions.load();
if sessions.contains_key(process_id) {
@@ -852,19 +936,28 @@ impl Inner {
Ok(())
}
async fn remove_session(&self, process_id: &ProcessId) -> Option<Arc<SessionState>> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
fn remove_session_if(&self, process_id: &ProcessId, expected: &Arc<SessionState>) {
let _sessions_write_guard = self
.sessions_write_lock
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let sessions = self.sessions.load();
let session = sessions.get(process_id).cloned();
session.as_ref()?;
if !sessions
.get(process_id)
.is_some_and(|session| Arc::ptr_eq(session, expected))
{
return;
}
let mut next_sessions = sessions.as_ref().clone();
next_sessions.remove(process_id);
self.sessions.store(Arc::new(next_sessions));
session
}
async fn take_all_sessions(&self) -> HashMap<ProcessId, Arc<SessionState>> {
let _sessions_write_guard = self.sessions_write_lock.lock().await;
fn take_all_sessions(&self) -> HashMap<ProcessId, Arc<SessionState>> {
let _sessions_write_guard = self
.sessions_write_lock
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let sessions = self.sessions.load();
let drained_sessions = sessions.as_ref().clone();
self.sessions.store(Arc::new(HashMap::new()));
@@ -892,31 +985,20 @@ fn is_transport_closed_error(error: &ExecServerError) -> bool {
)
}
fn record_disconnected(inner: &Arc<Inner>, message: String) -> String {
// The first observer records the canonical disconnect reason. Session
// draining stays with the reader task so it can preserve notification
// ordering before publishing the terminal failure.
if let Some(message) = inner.set_disconnected(message.clone()) {
message
} else {
inner.disconnected.get().cloned().unwrap_or(message)
}
}
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
let sessions = inner.take_all_sessions().await;
fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
let sessions = inner.take_all_sessions();
for (_, session) in sessions {
// Sessions synthesize a closed read response and emit a pushed Failed
// event. That covers both polling consumers and streaming consumers
// such as environment-backed MCP stdio.
session.set_failure(message.clone()).await;
session.set_failure(message.clone());
}
}
/// Fails all in-flight work that depends on the shared JSON-RPC transport.
async fn fail_all_in_flight_work(inner: &Arc<Inner>, message: String) {
fail_all_sessions(inner, message.clone()).await;
fail_all_sessions(inner, message.clone());
inner.fail_all_http_body_streams(message).await;
}
@@ -937,7 +1019,7 @@ async fn handle_server_notification(
chunk: params.chunk,
}));
if published_closed {
inner.remove_session(&params.process_id).await;
inner.remove_session_if(&params.process_id, &session);
}
}
}
@@ -951,7 +1033,7 @@ async fn handle_server_notification(
exit_code: params.exit_code,
});
if published_closed {
inner.remove_session(&params.process_id).await;
inner.remove_session_if(&params.process_id, &session);
}
}
}
@@ -966,7 +1048,7 @@ async fn handle_server_notification(
let published_closed =
session.publish_ordered_event(ExecProcessEvent::Closed { seq: params.seq });
if published_closed {
inner.remove_session(&params.process_id).await;
inner.remove_session_if(&params.process_id, &session);
}
}
}
@@ -1020,6 +1102,7 @@ mod tests {
#[cfg(not(windows))]
use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT;
use crate::client_api::ExecServerTransportParams;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
@@ -1136,19 +1219,6 @@ mod tests {
}
}
async fn wait_for_disconnect(client: &ExecServerClient) {
timeout(Duration::from_secs(1), async {
loop {
if client.is_disconnected() {
return;
}
tokio::task::yield_now().await;
}
})
.await
.expect("client should observe disconnect");
}
#[cfg(not(windows))]
#[tokio::test]
async fn connect_stdio_command_initializes_json_rpc_client() {
@@ -1567,7 +1637,7 @@ mod tests {
}
#[tokio::test]
async fn remote_websocket_client_replaces_disconnected_client_with_fresh_session() {
async fn remote_websocket_client_resumes_session() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
@@ -1575,28 +1645,27 @@ mod tests {
"ws://{}",
listener.local_addr().expect("listener should have address")
);
let server = tokio::spawn({
async move {
let mut first = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut first,
"session-1",
/*expected_resume_session_id*/ None,
)
.await;
first
.close(None)
.await
.expect("first websocket should close");
let (resumed_tx, resumed_rx) = oneshot::channel();
let (finish_tx, finish_rx) = oneshot::channel();
let server = tokio::spawn(async move {
let mut first = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut first,
"session-1",
/*expected_resume_session_id*/ None,
)
.await;
first.close(None).await.expect("websocket should close");
let mut second = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut second,
"session-2",
/*expected_resume_session_id*/ None,
)
.await;
}
let mut resumed = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut resumed,
"session-1",
/*expected_resume_session_id*/ Some("session-1"),
)
.await;
resumed_tx.send(()).expect("resume should signal");
finish_rx.await.expect("test should finish");
});
let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl {
@@ -1604,16 +1673,103 @@ mod tests {
connect_timeout: Duration::from_secs(1),
initialize_timeout: Duration::from_secs(1),
});
let first = client.get().await.expect("first client should connect");
wait_for_disconnect(&first).await;
let stable_client = client.get().await.expect("client should connect");
timeout(Duration::from_secs(1), resumed_rx)
.await
.expect("session resume should not time out")
.expect("session resume should signal");
let reused_client = client.get().await.expect("client should stay connected");
assert_eq!(stable_client.session_id().as_deref(), Some("session-1"));
assert!(Arc::ptr_eq(&stable_client.inner, &reused_client.inner));
finish_tx.send(()).expect("test should finish");
server.await.expect("server task should finish");
}
let (replacement_a, replacement_b) = tokio::join!(client.get(), client.get());
let replacement_a = replacement_a.expect("first replacement should connect");
let replacement_b = replacement_b.expect("second replacement should reuse client");
assert_eq!(replacement_a.session_id().as_deref(), Some("session-2"));
assert_eq!(replacement_b.session_id().as_deref(), Some("session-2"));
assert!(Arc::ptr_eq(&replacement_a.inner, &replacement_b.inner));
#[tokio::test]
async fn explicit_resume_drains_notifications_before_initialize_response() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let websocket_url = format!(
"ws://{}",
listener.local_addr().expect("listener should have address")
);
let (initialized_tx, initialized_rx) = oneshot::channel();
let (finish_tx, finish_rx) = oneshot::channel();
let server = tokio::spawn(async move {
let mut websocket = accept_websocket(&listener).await;
let initialize = read_jsonrpc_websocket(&mut websocket).await;
let request = match initialize {
JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request,
other => panic!("expected initialize request, got {other:?}"),
};
let params: crate::protocol::InitializeParams =
serde_json::from_value(request.params.expect("initialize params should exist"))
.expect("initialize params should deserialize");
assert_eq!(params.resume_session_id.as_deref(), Some("session-1"));
for seq in 1..=256 {
write_jsonrpc_websocket(
&mut websocket,
JSONRPCMessage::Notification(JSONRPCNotification {
method: EXEC_OUTPUT_DELTA_METHOD.to_string(),
params: Some(
serde_json::to_value(ExecOutputDeltaNotification {
process_id: ProcessId::from("busy-process"),
seq,
stream: ExecOutputStream::Stdout,
chunk: b"output".to_vec().into(),
})
.expect("output notification should serialize"),
),
}),
)
.await;
}
write_jsonrpc_websocket(
&mut websocket,
JSONRPCMessage::Response(JSONRPCResponse {
id: request.id,
result: serde_json::to_value(InitializeResponse {
session_id: "session-1".to_string(),
})
.expect("initialize response should serialize"),
}),
)
.await;
let initialized = read_jsonrpc_websocket(&mut websocket).await;
match initialized {
JSONRPCMessage::Notification(notification)
if notification.method == INITIALIZED_METHOD => {}
other => panic!("expected initialized notification, got {other:?}"),
}
initialized_tx
.send(())
.expect("initialized notification should signal");
finish_rx.await.expect("test should finish");
});
let client = timeout(
Duration::from_secs(1),
ExecServerClient::connect_websocket(RemoteExecServerConnectArgs {
websocket_url,
client_name: "test-client".to_string(),
connect_timeout: Duration::from_secs(1),
initialize_timeout: Duration::from_secs(1),
resume_session_id: Some("session-1".to_string()),
}),
)
.await
.expect("explicit resume should not time out")
.expect("explicit resume should connect");
assert_eq!(client.session_id().as_deref(), Some("session-1"));
timeout(Duration::from_secs(1), initialized_rx)
.await
.expect("initialized notification should not time out")
.expect("initialized notification should signal");
finish_tx.send(()).expect("test should finish");
server.await.expect("server task should finish");
}
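The test above sends 256 notifications ahead of the initialize response, which is more than the 128-slot event channel created in `RpcClient::new`. A minimal sketch of why the reader has to drain concurrently, using a bounded `std::sync::mpsc::sync_channel` instead of the real Tokio channel; `Message` and `drain_until_initialized` are hypothetical names used only for illustration:

```rust
use std::sync::mpsc::sync_channel;
use std::thread;

/// Hypothetical stand-in for messages arriving on a resumed connection.
#[derive(Debug)]
enum Message {
    Notification(u64),
    InitializeResponse,
}

/// Sends `burst` notifications followed by the initialize response through a
/// bounded channel while the caller drains it concurrently. Returns how many
/// notifications were seen before the response arrived.
fn drain_until_initialized(capacity: usize, burst: u64) -> u64 {
    let (tx, rx) = sync_channel::<Message>(capacity);
    let producer = thread::spawn(move || {
        for seq in 1..=burst {
            // `send` blocks whenever the channel is full, so the producer only
            // makes progress while the receiver keeps draining.
            tx.send(Message::Notification(seq)).expect("receiver alive");
        }
        tx.send(Message::InitializeResponse).expect("receiver alive");
    });

    let mut seen = 0;
    for message in rx {
        match message {
            Message::Notification(_) => seen += 1,
            Message::InitializeResponse => break,
        }
    }
    producer.join().expect("producer should finish");
    seen
}
```

If nothing read from the channel until the response had been delivered, the producer would block on the 129th notification and the response would never be sent, which is the stall the test guards against.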
@@ -42,6 +42,7 @@ impl ExecServerClient {
&self,
mut params: HttpRequestParams,
) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> {
let rpc_client = self.inner.rpc_client().await?;
params.stream_response = true;
let request_id = self.inner.next_http_body_stream_request_id();
params.request_id = request_id.clone();
@@ -51,7 +52,10 @@ impl ExecServerClient {
.await?;
let mut registration =
HttpBodyStreamRegistration::new(Arc::clone(&self.inner), request_id.clone());
- let response = match self.call(HTTP_REQUEST_METHOD, &params).await {
+ let response = match self
+ .call_rpc(&rpc_client, HTTP_REQUEST_METHOD, &params)
+ .await
+ {
Ok(response) => response,
Err(error) => {
self.inner.remove_http_body_stream(&request_id).await;
+516
@@ -0,0 +1,516 @@
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::time::sleep;
use tokio::time::timeout_at;
use super::ConnectionStatus;
use super::ExecServerClient;
use super::ExecServerError;
use super::Inner;
use super::OrderedSessionEvents;
use super::SessionState;
use super::disconnected_message;
use super::fail_all_in_flight_work;
use super::handle_server_notification;
use super::is_transport_closed_error;
use crate::process::ExecProcessEvent;
use crate::protocol::EXEC_READ_METHOD;
use crate::protocol::EXEC_TERMINATE_METHOD;
use crate::protocol::ReadParams;
use crate::protocol::ReadResponse;
use crate::protocol::TerminateParams;
use crate::protocol::TerminateResponse;
use crate::rpc::RpcClient;
use crate::rpc::RpcClientEvent;
use crate::rpc::SESSION_ALREADY_ATTACHED_ERROR_CODE;
#[cfg(test)]
const SESSION_RECOVERY_TIMEOUT: Duration = Duration::from_millis(500);
#[cfg(not(test))]
// Leave margin inside the server's 30-second retention windows because the
// client and server start their disconnect clocks independently.
const SESSION_RECOVERY_TIMEOUT: Duration = Duration::from_secs(25);
const SESSION_RECOVERY_RETRY_INTERVAL: Duration = Duration::from_millis(100);
impl SessionState {
fn last_published_seq(&self) -> u64 {
self.ordered_events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.last_published_seq
}
fn recover_events(&self, response: ReadResponse) -> Result<bool, ExecServerError> {
let ReadResponse {
chunks,
next_seq,
exited,
exit_code,
closed,
failure,
} = response;
if let Some(message) = failure {
return Err(ExecServerError::Protocol(format!(
"process failed while recovering: {message}"
)));
}
let target_seq = next_seq.saturating_sub(1);
let published_closed = {
let mut ordered_events = self
.ordered_events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if ordered_events.failure.is_some()
|| ordered_events.closed_published
|| target_seq <= ordered_events.last_published_seq
{
return Ok(false);
}
for chunk in chunks {
if chunk.seq > ordered_events.last_published_seq {
ordered_events
.pending
.entry(chunk.seq)
.or_insert(ExecProcessEvent::Output(chunk));
}
}
if closed {
match ordered_events.pending.get(&target_seq) {
Some(ExecProcessEvent::Closed { .. }) => {}
Some(_) => {
return Err(ExecServerError::Protocol(format!(
"process close sequence {target_seq} conflicts with recovered output"
)));
}
None => {
ordered_events
.pending
.insert(target_seq, ExecProcessEvent::Closed { seq: target_seq });
}
}
}
let exit_known = ordered_events.exit_published
|| ordered_events
.pending
.range(..=target_seq)
.any(|(_, event)| matches!(event, ExecProcessEvent::Exited { .. }));
let event_count = target_seq - ordered_events.last_published_seq;
let retained_count = ordered_events
.pending
.range(ordered_events.last_published_seq.saturating_add(1)..=target_seq)
.count() as u64;
let missing_count = event_count.saturating_sub(retained_count);
if exited && !exit_known {
if missing_count != 1 {
return Err(recovery_gap_error(target_seq));
}
let seq = first_missing_seq(&ordered_events, target_seq);
let exit_code = exit_code.ok_or_else(|| {
ExecServerError::Protocol(
"recovering exited process did not include its exit code".to_string(),
)
})?;
ordered_events
.pending
.insert(seq, ExecProcessEvent::Exited { seq, exit_code });
} else if missing_count != 0 {
return Err(recovery_gap_error(target_seq));
}
self.publish_ready(&mut ordered_events)
};
self.note_change(target_seq);
Ok(published_closed)
}
}
fn first_missing_seq(events: &OrderedSessionEvents, target_seq: u64) -> u64 {
let mut expected = events.last_published_seq.saturating_add(1);
for seq in events
.pending
.range(expected..=target_seq)
.map(|(seq, _)| *seq)
{
if seq != expected {
break;
}
expected = expected.saturating_add(1);
}
expected
}
fn recovery_gap_error(target_seq: u64) -> ExecServerError {
ExecServerError::Protocol(format!(
"process events are no longer retained while recovering through sequence {target_seq}"
))
}
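The gap accounting in `recover_events` and `first_missing_seq` can be sketched on its own. This is a simplified illustration, not the PR's types: `Event`, `first_missing`, and `missing_count` are hypothetical, and both helpers assume `target > last_published`, which the early return in `recover_events` guarantees in the real code.

```rust
use std::collections::BTreeMap;

/// Hypothetical, simplified event type; the real `ExecProcessEvent` carries payloads.
#[derive(Debug, Clone, PartialEq)]
enum Event {
    Output,
    Exited { exit_code: i32 },
}

/// Returns the first sequence number in `last_published + 1 ..= target` that
/// has no buffered event, or `target + 1` when the range is contiguous.
/// Assumes `target > last_published`.
fn first_missing(pending: &BTreeMap<u64, Event>, last_published: u64, target: u64) -> u64 {
    let mut expected = last_published.saturating_add(1);
    for (&seq, _) in pending.range(expected..=target) {
        if seq != expected {
            break;
        }
        expected = expected.saturating_add(1);
    }
    expected
}

/// Counts the sequence numbers in `last_published + 1 ..= target` that are
/// not buffered. Assumes `target > last_published`.
fn missing_count(pending: &BTreeMap<u64, Event>, last_published: u64, target: u64) -> u64 {
    let total = target - last_published;
    let retained = pending
        .range(last_published.saturating_add(1)..=target)
        .count() as u64;
    total.saturating_sub(retained)
}
```

With events 3, 4, and 6 buffered after publishing through 2, exactly one slot (5) is missing. That is the only case in which the recovered exit event can be placed unambiguously; any other gap count is reported as a retention error.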
impl Inner {
pub(super) async fn rpc_client(self: &Arc<Self>) -> Result<Arc<RpcClient>, ExecServerError> {
let mut connection_changed = self.connection_changed.subscribe();
loop {
if let Some(message) = self.failure_message() {
return Err(ExecServerError::Disconnected(message));
}
let rpc_client = {
let connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match &connection.status {
ConnectionStatus::Connected(rpc_client) => Some(Arc::clone(rpc_client)),
ConnectionStatus::Recovering | ConnectionStatus::Failed(_) => None,
}
};
let Some(rpc_client) = rpc_client else {
let _ = connection_changed.changed().await;
continue;
};
if !rpc_client.is_disconnected() {
return Ok(rpc_client);
}
let _ = connection_changed.changed().await;
}
}
pub(super) fn begin_process_start(&self, expected: &Arc<RpcClient>) -> bool {
let mut connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let ConnectionStatus::Connected(current) = &connection.status else {
return false;
};
if !Arc::ptr_eq(current, expected) || expected.is_disconnected() {
return false;
}
connection.active_process_starts += 1;
true
}
pub(super) fn finish_process_start(&self) {
{
let mut connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if connection.active_process_starts == 0 {
tracing::error!("finished an exec-server process start that was not active");
return;
}
connection.active_process_starts -= 1;
}
self.notify_connection_changed();
}
pub(super) fn is_failed(&self) -> bool {
self.failure_message().is_some()
}
pub(super) fn failure_message(&self) -> Option<String> {
let connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match &connection.status {
ConnectionStatus::Failed(message) => Some(message.clone()),
ConnectionStatus::Connected(_) | ConnectionStatus::Recovering => None,
}
}
fn request_recovery(
self: &Arc<Self>,
failed_rpc_client: Arc<RpcClient>,
disconnect_message: String,
) {
let should_recover = {
let mut connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match &connection.status {
ConnectionStatus::Connected(current)
if Arc::ptr_eq(current, &failed_rpc_client) =>
{
connection.status = ConnectionStatus::Recovering;
true
}
ConnectionStatus::Connected(_)
| ConnectionStatus::Recovering
| ConnectionStatus::Failed(_) => false,
}
};
if !should_recover {
return;
}
self.notify_connection_changed();
let inner = Arc::clone(self);
tokio::spawn(async move {
inner.recover(disconnect_message).await;
});
}
async fn recover(self: &Arc<Self>, disconnect_message: String) {
let deadline = Instant::now() + SESSION_RECOVERY_TIMEOUT;
self.fail_all_http_body_streams(disconnect_message.clone())
.await;
if timeout_at(deadline, self.wait_for_process_starts())
.await
.is_err()
{
let message = format!(
"{disconnect_message}; failed to resume exec-server session: recovery timed out after {SESSION_RECOVERY_TIMEOUT:?}"
);
self.fail(message).await;
return;
}
if self.reconnect_strategy.is_none() {
self.fail(disconnect_message).await;
return;
}
let Some(session_id) = self.session_id.get().cloned() else {
let message = format!(
"{disconnect_message}; failed to resume exec-server session: missing session id"
);
self.fail(message).await;
return;
};
let last_error = loop {
match timeout_at(deadline, self.resume_once(&session_id)).await {
Ok(Ok(candidate)) if !candidate.is_disconnected() => {
if self.install_recovered_client(candidate) {
return;
}
}
Ok(Ok(_)) => {}
Ok(Err(error)) if !is_retryable_recovery_error(&error) => {
break error.to_string();
}
Ok(Err(_)) => {}
Err(_) => {
break format!("recovery timed out after {SESSION_RECOVERY_TIMEOUT:?}");
}
}
let now = Instant::now();
if now >= deadline {
break format!("recovery timed out after {SESSION_RECOVERY_TIMEOUT:?}");
}
sleep(SESSION_RECOVERY_RETRY_INTERVAL.min(deadline - now)).await;
};
let message =
format!("{disconnect_message}; failed to resume exec-server session: {last_error}");
self.fail(message).await;
}
async fn wait_for_process_starts(&self) {
let mut connection_changed = self.connection_changed.subscribe();
loop {
let starts_are_done = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.active_process_starts
== 0;
if starts_are_done {
return;
}
let _ = connection_changed.changed().await;
}
}
fn install_recovered_client(&self, rpc_client: Arc<RpcClient>) -> bool {
let installed = {
let mut connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
if !matches!(connection.status, ConnectionStatus::Recovering)
|| rpc_client.is_disconnected()
{
false
} else {
connection.status = ConnectionStatus::Connected(rpc_client);
true
}
};
if installed {
self.notify_connection_changed();
}
installed
}
fn notify_connection_changed(&self) {
self.connection_changed.send_replace(());
}
async fn resume_once(
self: &Arc<Self>,
session_id: &str,
) -> Result<Arc<RpcClient>, ExecServerError> {
let reconnect_strategy = self
.reconnect_strategy
.as_ref()
.ok_or_else(|| ExecServerError::Protocol("missing reconnect strategy".to_string()))?;
let (connection, options) = reconnect_strategy.resume(session_id).await?;
let (rpc_client, events_rx) = RpcClient::new(connection);
let rpc_client = Arc::new(rpc_client);
let client = ExecServerClient {
inner: Arc::clone(self),
};
// Resuming a session redirects notifications from its running processes
// to this connection during initialize. Drain them immediately so a
// burst cannot fill the bounded event channel and block the initialize
// response behind it.
client.spawn_rpc_reader(&rpc_client, events_rx);
client.initialize_rpc(&rpc_client, options).await?;
self.recover_processes(&rpc_client).await?;
Ok(rpc_client)
}
async fn recover_processes(
self: &Arc<Self>,
rpc_client: &RpcClient,
) -> Result<(), ExecServerError> {
let sessions = self.sessions.load_full();
for (process_id, session) in sessions.iter() {
if !session.recoverable.load(Ordering::Acquire) {
continue;
}
let response = rpc_client
.call::<_, ReadResponse>(
EXEC_READ_METHOD,
&ReadParams {
process_id: process_id.clone(),
after_seq: Some(session.last_published_seq()),
max_bytes: None,
wait_ms: Some(0),
},
)
.await
.map_err(ExecServerError::from);
let recovered = match response {
Ok(response) => session.recover_events(response),
Err(error) if is_transport_closed_error(&error) => return Err(error),
Err(error) => Err(error),
};
match recovered {
Ok(true) => self.remove_session_if(process_id, session),
Ok(false) => {}
Err(error) => {
let terminated: Result<TerminateResponse, ExecServerError> = rpc_client
.call(
EXEC_TERMINATE_METHOD,
&TerminateParams {
process_id: process_id.clone(),
},
)
.await
.map_err(ExecServerError::from);
if let Err(terminate_error) = terminated
&& is_transport_closed_error(&terminate_error)
{
return Err(terminate_error);
}
self.remove_session_if(process_id, session);
session.set_failure(format!("failed to recover process {process_id}: {error}"));
}
}
}
Ok(())
}
async fn fail(self: &Arc<Self>, message: String) {
let (message, newly_failed) = {
let mut connection = self
.connection
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match &connection.status {
ConnectionStatus::Failed(existing) => (existing.clone(), false),
ConnectionStatus::Connected(_) | ConnectionStatus::Recovering => {
connection.status = ConnectionStatus::Failed(message.clone());
(message, true)
}
}
};
if newly_failed {
self.notify_connection_changed();
fail_all_in_flight_work(self, message.clone()).await;
}
}
}
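The generation check used by `request_recovery`, `install_recovered_client`, and `fail` can be illustrated in isolation. The sketch below is an assumption-laden reduction: `Connection`, `Status`, and `Client` are hypothetical stand-ins for `RpcClient`, `ConnectionStatus`, and `Inner`, and all async work and notifications are omitted.

```rust
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

/// Hypothetical stand-in for one physical RPC connection (one generation).
struct Connection {
    label: &'static str,
}

enum Status {
    Connected(Arc<Connection>),
    Recovering,
    Failed(String),
}

struct Client {
    status: Mutex<Status>,
}

impl Client {
    fn new(connection: Arc<Connection>) -> Self {
        Self { status: Mutex::new(Status::Connected(connection)) }
    }

    fn lock(&self) -> MutexGuard<'_, Status> {
        self.status.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Starts recovery only if `failed` is the connection currently installed,
    /// so a late disconnect from an older generation is ignored.
    fn request_recovery(&self, failed: &Arc<Connection>) -> bool {
        let mut status = self.lock();
        let is_current =
            matches!(&*status, Status::Connected(current) if Arc::ptr_eq(current, failed));
        if is_current {
            *status = Status::Recovering;
        }
        is_current
    }

    /// Installs a replacement only while recovery is still in progress.
    fn install_recovered(&self, replacement: Arc<Connection>) -> bool {
        let mut status = self.lock();
        if matches!(*status, Status::Recovering) {
            *status = Status::Connected(replacement);
            true
        } else {
            false
        }
    }

    /// Moves to `Failed`; the first failure message wins and is returned.
    fn fail(&self, message: &str) -> String {
        let mut status = self.lock();
        if let Status::Failed(existing) = &*status {
            return existing.clone();
        }
        *status = Status::Failed(message.to_string());
        message.to_string()
    }

    fn current_label(&self) -> Option<&'static str> {
        let status = self.lock();
        match &*status {
            Status::Connected(current) => Some(current.label),
            Status::Recovering | Status::Failed(_) => None,
        }
    }
}
```

Comparing by pointer identity rather than by a counter means each `Arc` allocation acts as its own generation tag, so no separate generation number has to be kept in sync with the connection.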
impl ExecServerClient {
pub(super) fn spawn_rpc_reader(
&self,
rpc_client: &Arc<RpcClient>,
mut events_rx: mpsc::Receiver<RpcClientEvent>,
) {
let inner = Arc::downgrade(&self.inner);
let rpc_client = Arc::downgrade(rpc_client);
tokio::spawn(async move {
while let Some(event) = events_rx.recv().await {
let (Some(inner), Some(rpc_client)) = (inner.upgrade(), rpc_client.upgrade())
else {
return;
};
match event {
RpcClientEvent::Notification(notification) => {
if let Err(error) = handle_server_notification(&inner, notification).await {
rpc_client.close_transport().await;
inner.request_recovery(
rpc_client,
format!("exec-server notification handling failed: {error}"),
);
return;
}
}
RpcClientEvent::Disconnected { reason } => {
inner.request_recovery(rpc_client, disconnected_message(reason.as_deref()));
return;
}
}
}
});
}
}
fn is_retryable_recovery_error(error: &ExecServerError) -> bool {
is_transport_closed_error(error)
|| matches!(
error,
ExecServerError::WebSocketConnectTimeout { .. }
| ExecServerError::WebSocketConnect { .. }
| ExecServerError::InitializeTimedOut { .. }
)
|| matches!(
error,
ExecServerError::EnvironmentRegistryRequest(error)
if error.is_connect() || error.is_timeout()
)
|| matches!(
error,
ExecServerError::EnvironmentRegistryHttp { status, .. }
if status.is_server_error()
|| *status == reqwest::StatusCode::REQUEST_TIMEOUT
|| *status == reqwest::StatusCode::TOO_MANY_REQUESTS
)
|| matches!(
error,
ExecServerError::Server { code, .. }
if *code == SESSION_ALREADY_ATTACHED_ERROR_CODE
)
}
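The retry policy in `recover` (retry retryable errors at a fixed interval, stop on a permanent error or at the deadline) can be sketched synchronously. This is a simplified analogue under stated assumptions: `AttemptError` and `retry_until_deadline` are hypothetical, it uses `std::thread::sleep` rather than Tokio, and unlike the real loop it does not bound a single attempt with `timeout_at`.

```rust
use std::thread::sleep;
use std::time::{Duration, Instant};

/// Hypothetical classification of one failed resume attempt.
#[derive(Debug, Clone, PartialEq)]
enum AttemptError {
    Retryable(String),
    Permanent(String),
}

/// Calls `attempt` until it succeeds, returns a permanent error, or the
/// deadline passes. The final sleep is clamped so it never overshoots the
/// deadline.
fn retry_until_deadline<T>(
    timeout: Duration,
    retry_interval: Duration,
    mut attempt: impl FnMut() -> Result<T, AttemptError>,
) -> Result<T, String> {
    let deadline = Instant::now() + timeout;
    loop {
        match attempt() {
            Ok(value) => return Ok(value),
            Err(AttemptError::Permanent(message)) => return Err(message),
            Err(AttemptError::Retryable(_)) => {}
        }
        let now = Instant::now();
        if now >= deadline {
            return Err(format!("recovery timed out after {timeout:?}"));
        }
        sleep(retry_interval.min(deadline - now));
    }
}
```

Treating "session already attached" as retryable, as `is_retryable_recovery_error` does, fits this loop: the server may not have noticed the old connection closing yet, so a later attempt inside the deadline can still succeed.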
+103 -14
@@ -1,4 +1,7 @@
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
@@ -14,12 +17,15 @@ use crate::ExecServerClient;
use crate::ExecServerError;
use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT;
use crate::client_api::DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT;
use crate::client_api::ExecServerClientConnectOptions;
use crate::client_api::NoiseRendezvousConnectArgs;
use crate::client_api::NoiseRendezvousConnectBundle;
use crate::client_api::NoiseRendezvousConnectProvider;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::client_api::StdioExecServerCommand;
use crate::client_api::StdioExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::noise_channel::NoiseChannelIdentity;
use crate::noise_relay::NoiseHarnessConnectionArgs;
use crate::noise_relay::noise_harness_connection_from_websocket;
use crate::noise_relay::noise_relay_websocket_config;
@@ -27,6 +33,57 @@ use crate::relay::harness_connection_from_websocket;
const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment";
/// Reopens the transport for one logical exec-server client session.
///
/// URL connections reuse their configured endpoint. Noise connections retain
/// the harness identity but fetch a fresh single-use authorization bundle for
/// every physical connection attempt.
#[derive(Clone)]
pub(crate) enum ExecServerReconnectStrategy {
WebSocket(RemoteExecServerConnectArgs),
NoiseRendezvous {
provider: Arc<dyn NoiseRendezvousConnectProvider>,
identity: NoiseChannelIdentity,
client_name: String,
connect_timeout: Duration,
initialize_timeout: Duration,
},
}
impl ExecServerReconnectStrategy {
pub(crate) async fn resume(
&self,
session_id: &str,
) -> Result<(JsonRpcConnection, ExecServerClientConnectOptions), ExecServerError> {
match self {
Self::WebSocket(args) => {
let mut args = args.clone();
args.resume_session_id = Some(session_id.to_string());
let connection = ExecServerClient::open_websocket_connection(&args).await?;
Ok((connection, args.into()))
}
Self::NoiseRendezvous {
provider,
identity,
client_name,
connect_timeout,
initialize_timeout,
} => {
let bundle = provider.connect_bundle(identity.public_key()).await?;
ExecServerClient::open_noise_rendezvous_connection(NoiseRendezvousConnectArgs {
bundle,
harness_identity: identity.clone(),
client_name: client_name.clone(),
connect_timeout: *connect_timeout,
initialize_timeout: *initialize_timeout,
resume_session_id: Some(session_id.to_string()),
})
.await
}
}
}
}
impl ExecServerClient {
/// Open the selected transport and run the common JSON-RPC initialization.
/// Noise connection details are fetched here so reconnects get a fresh URL
@@ -53,16 +110,25 @@ impl ExecServerClient {
provider,
identity,
} => {
- let bundle = provider.connect_bundle(identity.public_key()).await?;
- Self::connect_noise_rendezvous(NoiseRendezvousConnectArgs {
- bundle,
- harness_identity: identity,
+ let reconnect_strategy = ExecServerReconnectStrategy::NoiseRendezvous {
+ provider: Arc::clone(&provider),
+ identity: identity.clone(),
client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
connect_timeout: DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT,
initialize_timeout: DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT,
- resume_session_id: None,
- })
- .await
+ };
+ let bundle = provider.connect_bundle(identity.public_key()).await?;
+ let (connection, options) =
+ Self::open_noise_rendezvous_connection(NoiseRendezvousConnectArgs {
+ bundle,
+ harness_identity: identity,
+ client_name: ENVIRONMENT_CLIENT_NAME.to_string(),
+ connect_timeout: DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT,
+ initialize_timeout: DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT,
+ resume_session_id: None,
+ })
+ .await?;
+ Self::connect_with_recovery(connection, options, Some(reconnect_strategy)).await
}
crate::client_api::ExecServerTransportParams::StdioCommand {
command,
@@ -82,6 +148,19 @@ impl ExecServerClient {
pub async fn connect_websocket(
args: RemoteExecServerConnectArgs,
) -> Result<Self, ExecServerError> {
let connection = Self::open_websocket_connection(&args).await?;
let options = args.clone().into();
Self::connect_with_recovery(
connection,
options,
Some(ExecServerReconnectStrategy::WebSocket(args)),
)
.await
}
pub(crate) async fn open_websocket_connection(
args: &RemoteExecServerConnectArgs,
) -> Result<JsonRpcConnection, ExecServerError> {
ensure_rustls_crypto_provider();
let websocket_url = args.websocket_url.clone();
let connect_timeout = args.connect_timeout;
@@ -102,15 +181,26 @@ impl ExecServerClient {
} else {
JsonRpcConnection::from_websocket(stream, connection_label)
};
- Self::connect(connection, args.into()).await
+ Ok(connection)
}
- /// Connect to one exec-server through an authenticated rendezvous stream.
+ /// Connect to one exec-server through an authenticated rendezvous stream
+ /// using a caller-supplied single-use authorization bundle.
///
/// The executor key is pinned before JSON-RPC starts; the websocket carries
- /// only ciphertext after that.
+ /// only ciphertext after that. Environment-managed connections use a
+ /// retained [`NoiseRendezvousConnectProvider`] so recovery can fetch a fresh
+ /// bundle for each reconnect.
pub async fn connect_noise_rendezvous(
args: NoiseRendezvousConnectArgs,
) -> Result<Self, ExecServerError> {
+ let (connection, options) = Self::open_noise_rendezvous_connection(args).await?;
+ Self::connect(connection, options).await
+ }
+ pub(crate) async fn open_noise_rendezvous_connection(
+ args: NoiseRendezvousConnectArgs,
+ ) -> Result<(JsonRpcConnection, ExecServerClientConnectOptions), ExecServerError> {
ensure_rustls_crypto_provider();
// Keep the registry-issued URL, key, and authorization together for this
// connection attempt.
@@ -164,15 +254,14 @@ impl ExecServerClient {
harness_key_authorization,
},
);
- Self::connect(
+ Ok((
connection,
- crate::client_api::ExecServerClientConnectOptions {
+ ExecServerClientConnectOptions {
client_name,
initialize_timeout,
resume_session_id,
},
- )
- .await
+ ))
}
pub(crate) async fn connect_stdio_command(
+43 -8
@@ -81,8 +81,10 @@ struct RunningProcess {
closed: bool,
}
+ struct ProcessStart;
enum ProcessEntry {
- Starting,
+ Starting(Arc<ProcessStart>),
Running(Box<RunningProcess>),
}
@@ -128,7 +130,7 @@ impl LocalProcess {
processes
.drain()
.filter_map(|(_, process)| match process {
- ProcessEntry::Starting => None,
+ ProcessEntry::Starting(_) => None,
ProcessEntry::Running(process) => Some(process),
})
.collect::<Vec<_>>()
@@ -163,6 +165,7 @@ impl LocalProcess {
))
})?;
let start = Arc::new(ProcessStart);
{
let mut process_map = self.inner.processes.lock().await;
if process_map.contains_key(&process_id) {
@@ -170,7 +173,10 @@ impl LocalProcess {
"process {process_id} already exists"
)));
}
- process_map.insert(process_id.clone(), ProcessEntry::Starting);
+ process_map.insert(
+ process_id.clone(),
+ ProcessEntry::Starting(Arc::clone(&start)),
+ );
}
let env = child_env(&params);
@@ -207,7 +213,10 @@ impl LocalProcess {
Ok(spawned) => spawned,
Err(err) => {
let mut process_map = self.inner.processes.lock().await;
- if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) {
+ if matches!(
+ process_map.get(&process_id),
+ Some(ProcessEntry::Starting(current)) if Arc::ptr_eq(current, &start)
+ ) {
process_map.remove(&process_id);
}
return Err(internal_error(err.to_string()));
@@ -222,6 +231,16 @@ impl LocalProcess {
);
{
let mut process_map = self.inner.processes.lock().await;
if !matches!(
process_map.get(&process_id),
Some(ProcessEntry::Starting(current)) if Arc::ptr_eq(current, &start)
) {
drop(process_map);
spawned.session.terminate();
return Err(invalid_request(format!(
"process {process_id} start was cancelled"
)));
}
process_map.insert(
process_id.clone(),
ProcessEntry::Running(Box::new(RunningProcess {
@@ -320,7 +339,9 @@ impl LocalProcess {
break;
}
}
if params.max_bytes.is_none() {
next_seq = process.next_seq;
}
(
ReadResponse {
chunks,
@@ -408,7 +429,7 @@ impl LocalProcess {
.signal(pty_process_signal(params.signal))
.map_err(|err| internal_error(format!("failed to signal process: {err}")))?
}
- Some(ProcessEntry::Starting) | None => {}
+ Some(ProcessEntry::Starting(_)) | None => {}
}
}
@@ -420,7 +441,7 @@ impl LocalProcess {
params: TerminateParams,
) -> Result<TerminateResponse, JSONRPCErrorError> {
let running = {
- let process_map = self.inner.processes.lock().await;
+ let mut process_map = self.inner.processes.lock().await;
match process_map.get(&params.process_id) {
Some(ProcessEntry::Running(process)) => {
if process.exit_code.is_some() {
@@ -429,7 +450,11 @@ impl LocalProcess {
process.session.terminate();
true
}
- Some(ProcessEntry::Starting) | None => false,
+ Some(ProcessEntry::Starting(_)) => {
+ process_map.remove(&params.process_id);
+ true
+ }
+ None => false,
}
};
@@ -915,6 +940,16 @@ mod tests {
)
.await
.expect("process should close");
let replay_after_exit = backend
.exec_read(ReadParams {
process_id: process.process_id.clone(),
after_seq: Some(1),
max_bytes: None,
wait_ms: Some(0),
})
.await
.expect("closed process should remain readable");
assert_eq!(replay_after_exit.next_seq, 4);
backend.shutdown().await;
}
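The `ProcessStart` token introduced above lets a start attempt detect that `terminate` cancelled it while the process was still spawning. A minimal sketch of that pattern, with hypothetical names (`StartToken`, `Entry`, `Processes`) and a plain `HashMap` in place of the locked process map:

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Marker type; each `Arc<StartToken>` allocation identifies one start attempt.
struct StartToken;

enum Entry {
    Starting(Arc<StartToken>),
    /// Hypothetical payload standing in for the running process state.
    Running(u32),
}

#[derive(Default)]
struct Processes {
    map: HashMap<String, Entry>,
}

impl Processes {
    /// Reserves `id` for a new start attempt, or returns `None` if it is taken.
    fn begin_start(&mut self, id: &str) -> Option<Arc<StartToken>> {
        if self.map.contains_key(id) {
            return None;
        }
        let token = Arc::new(StartToken);
        self.map.insert(id.to_string(), Entry::Starting(Arc::clone(&token)));
        Some(token)
    }

    /// Promotes the entry to `Running` only if it still holds this attempt's
    /// token; otherwise the start was cancelled (and possibly replaced).
    fn finish_start(&mut self, id: &str, token: &Arc<StartToken>, pid: u32) -> bool {
        let still_ours = matches!(
            self.map.get(id),
            Some(Entry::Starting(current)) if Arc::ptr_eq(current, token)
        );
        if still_ours {
            self.map.insert(id.to_string(), Entry::Running(pid));
        }
        still_ours
    }

    /// Terminating a `Starting` entry removes it, which cancels that start.
    fn terminate(&mut self, id: &str) -> bool {
        let is_starting = matches!(self.map.get(id), Some(Entry::Starting(_)));
        if is_starting {
            self.map.remove(id);
            return true;
        }
        matches!(self.map.get(id), Some(Entry::Running(_)))
    }

    fn is_running(&self, id: &str) -> bool {
        matches!(self.map.get(id), Some(Entry::Running(_)))
    }
}
```

A bare `Starting` variant could not tell the cancelled attempt apart from a newer attempt that reused the same process id; comparing tokens with `Arc::ptr_eq` makes that distinction.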
+1 -6
@@ -35,13 +35,8 @@ impl RemoteProcess {
&self,
params: ExecParams,
) -> Result<StartedExecProcess, crate::ExecServerError> {
- let process_id = params.process_id.clone();
let client = self.client.get().await?;
- let session = client.register_session(&process_id).await?;
- if let Err(err) = client.exec(params).await {
- session.unregister().await;
- return Err(err);
- }
+ let session = client.start_process(params).await?;
Ok(StartedExecProcess {
process: Arc::new(RemoteExecProcess { session }),
+34 -11
@@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
@@ -25,6 +26,8 @@ use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
use crate::connection::JsonRpcTransport;
pub(crate) const SESSION_ALREADY_ATTACHED_ERROR_CODE: i64 = -32010;
#[derive(Debug)]
pub(crate) enum RpcCallError {
/// The underlying JSON-RPC transport closed before this call completed.
@@ -225,6 +228,7 @@ pub(crate) struct RpcClient {
// immediately when the socket closes, even if no JSON-RPC error response
// can be delivered for their request id.
disconnected_rx: watch::Receiver<bool>,
closed: Arc<AtomicBool>,
next_request_id: AtomicI64,
transport_tasks: Vec<JoinHandle<()>>,
transport: JsonRpcTransport,
@@ -241,9 +245,11 @@ impl RpcClient {
transport,
} = connection;
let pending = Arc::new(Mutex::new(HashMap::<RequestId, PendingRequest>::new()));
let closed = Arc::new(AtomicBool::new(false));
let (event_tx, event_rx) = mpsc::channel(128);
let pending_for_reader = Arc::clone(&pending);
let closed_for_reader = Arc::clone(&closed);
let transport_for_reader = transport.clone();
let reader_task = tokio::spawn(async move {
let disconnect_reason = loop {
@@ -269,12 +275,13 @@ impl RpcClient {
}
};
closed_for_reader.store(true, Ordering::Release);
drain_pending(&pending_for_reader).await;
let _ = event_tx
.send(RpcClientEvent::Disconnected {
reason: disconnect_reason,
})
.await;
drain_pending(&pending_for_reader).await;
transport_for_reader.terminate();
});
@@ -283,6 +290,7 @@ impl RpcClient {
write_tx,
pending,
disconnected_rx,
closed,
next_request_id: AtomicI64::new(1),
transport_tasks,
transport,
@@ -296,24 +304,31 @@ impl RpcClient {
&self,
method: &str,
params: &P,
- ) -> Result<(), serde_json::Error> {
- let params = serde_json::to_value(params)?;
+ ) -> Result<(), RpcCallError> {
+ let params = serde_json::to_value(params).map_err(RpcCallError::Json)?;
+ if self.closed.load(Ordering::Acquire) || *self.disconnected_rx.borrow() {
+ return Err(RpcCallError::Closed);
+ }
self.write_tx
.send(JSONRPCMessage::Notification(JSONRPCNotification {
method: method.to_string(),
params: Some(params),
}))
.await
- .map_err(|_| {
- serde_json::Error::io(std::io::Error::new(
- std::io::ErrorKind::BrokenPipe,
- "JSON-RPC transport closed",
- ))
- })
+ .map_err(|_| RpcCallError::Closed)
}
pub(crate) fn is_disconnected(&self) -> bool {
- *self.disconnected_rx.borrow()
+ self.closed.load(Ordering::Acquire) || *self.disconnected_rx.borrow()
}
+ pub(crate) async fn close_transport(&self) {
+ self.closed.store(true, Ordering::Release);
+ self.transport.terminate();
+ for task in &self.transport_tasks {
+ task.abort();
+ }
+ drain_pending(&self.pending).await;
+ }
pub(crate) async fn call<P, T>(&self, method: &str, params: &P) -> Result<T, RpcCallError>
@@ -328,7 +343,7 @@ impl RpcClient {
// Registering the pending request and checking disconnect must be
// atomic with the reader's drain_pending path. Otherwise a call
// can sneak in after the drain and wait forever.
- if *self.disconnected_rx.borrow() {
+ if self.closed.load(Ordering::Acquire) || *self.disconnected_rx.borrow() {
return Err(RpcCallError::Closed);
}
pending.insert(request_id.clone(), response_tx);
@@ -417,6 +432,14 @@ pub(crate) fn invalid_request(message: String) -> JSONRPCErrorError {
}
}
pub(crate) fn session_already_attached(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: SESSION_ALREADY_ATTACHED_ERROR_CODE,
data: None,
message,
}
}
pub(crate) fn method_not_found(message: String) -> JSONRPCErrorError {
JSONRPCErrorError {
code: -32601,
@@ -257,7 +257,7 @@ async fn active_session_resume_is_rejected() {
.await
.expect_err("active session resume should fail");
- assert_eq!(err.code, -32600);
+ assert_eq!(err.code, crate::rpc::SESSION_ALREADY_ATTACHED_ERROR_CODE);
assert_eq!(
err.message,
format!(
@@ -9,12 +9,13 @@ use uuid::Uuid;
use crate::rpc::RpcNotificationSender;
use crate::rpc::invalid_request;
+ use crate::rpc::session_already_attached;
use crate::server::process_handler::ProcessHandler;
#[cfg(test)]
const DETACHED_SESSION_TTL: Duration = Duration::from_millis(200);
#[cfg(not(test))]
- const DETACHED_SESSION_TTL: Duration = Duration::from_secs(10);
+ const DETACHED_SESSION_TTL: Duration = Duration::from_secs(30);
pub(crate) struct SessionRegistry {
sessions: Mutex<HashMap<String, Arc<SessionEntry>>>,
@@ -82,7 +83,7 @@ impl SessionRegistry {
})?;
Ok(AttachOutcome::Expired { session_id, entry })
} else if entry.has_active_connection() {
- Err(invalid_request(format!(
+ Err(session_already_attached(format!(
"session {session_id} is already attached to another connection"
)))
} else {
@@ -176,6 +177,7 @@ impl SessionEntry {
return false;
}
self.process.set_notification_sender(/*notifications*/ None);
attachment.current_connection_id = None;
attachment.detached_connection_id = Some(connection_id);
attachment.detached_expires_at = Some(tokio::time::Instant::now() + DETACHED_SESSION_TTL);
@@ -245,10 +247,6 @@ impl SessionHandle {
return;
}
- self.entry
- .process
- .set_notification_sender(/*notifications*/ None);
let registry = Arc::clone(&self.registry);
let session_id = self.entry.session_id.clone();
let connection_id = self.connection_id;
@@ -14,8 +14,13 @@ use futures::StreamExt;
use tempfile::TempDir;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::io::copy_bidirectional;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio::time::sleep;
use tokio::time::timeout;
@@ -48,6 +53,20 @@ pub(crate) struct TestCodexHelperPaths {
pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
}
pub(crate) struct DisconnectableWebSocketProxy {
websocket_url: String,
pause_tx: Option<oneshot::Sender<()>>,
blocked_connection_rx: Option<oneshot::Receiver<()>>,
resume_tx: Option<oneshot::Sender<()>>,
task: JoinHandle<()>,
}
impl Drop for DisconnectableWebSocketProxy {
fn drop(&mut self) {
self.task.abort();
}
}
pub(crate) fn test_codex_helper_paths() -> anyhow::Result<TestCodexHelperPaths> {
let (helper_binary, codex_linux_sandbox_exe) = super::current_test_binary_helper_paths()?;
Ok(TestCodexHelperPaths {
@@ -106,6 +125,35 @@ impl ExecServerHarness {
Ok(())
}
pub(crate) async fn disconnectable_websocket_proxy(
&self,
) -> anyhow::Result<DisconnectableWebSocketProxy> {
let upstream = self
.websocket_url
.strip_prefix("ws://")
.ok_or_else(|| anyhow!("exec-server websocket URL must use ws://"))?
.to_string();
let listener = TcpListener::bind("127.0.0.1:0").await?;
let websocket_url = format!("ws://{}", listener.local_addr()?);
let (pause_tx, pause_rx) = oneshot::channel();
let (blocked_connection_tx, blocked_connection_rx) = oneshot::channel();
let (resume_tx, resume_rx) = oneshot::channel();
let task = tokio::spawn(run_disconnectable_proxy(
listener,
upstream,
pause_rx,
blocked_connection_tx,
resume_rx,
));
Ok(DisconnectableWebSocketProxy {
websocket_url,
pause_tx: Some(pause_tx),
blocked_connection_rx: Some(blocked_connection_rx),
resume_tx: Some(resume_tx),
task,
})
}
pub(crate) async fn send_request(
&mut self,
method: &str,
@@ -213,6 +261,85 @@ impl ExecServerHarness {
}
}
impl DisconnectableWebSocketProxy {
pub(crate) fn websocket_url(&self) -> &str {
&self.websocket_url
}
pub(crate) async fn pause_and_disconnect(&mut self) -> anyhow::Result<()> {
self.pause_tx
.take()
.ok_or_else(|| anyhow!("disconnectable websocket proxy is already paused"))?
.send(())
.map_err(|_| anyhow!("disconnectable websocket proxy stopped"))?;
let blocked_connection_rx = self
.blocked_connection_rx
.take()
.ok_or_else(|| anyhow!("disconnectable websocket proxy is already paused"))?;
timeout(CONNECT_TIMEOUT, blocked_connection_rx)
.await
.map_err(|_| anyhow!("timed out waiting for client reconnect attempt"))?
.map_err(|_| anyhow!("disconnectable websocket proxy stopped"))?;
Ok(())
}
pub(crate) fn resume(&mut self) -> anyhow::Result<()> {
self.resume_tx
.take()
.ok_or_else(|| anyhow!("disconnectable websocket proxy is already resumed"))?
.send(())
.map_err(|_| anyhow!("disconnectable websocket proxy stopped"))?;
Ok(())
}
}
async fn run_disconnectable_proxy(
listener: TcpListener,
upstream: String,
pause_rx: oneshot::Receiver<()>,
blocked_connection_tx: oneshot::Sender<()>,
mut resume_rx: oneshot::Receiver<()>,
) {
let Ok((mut downstream, _)) = listener.accept().await else {
return;
};
let Ok(mut upstream_stream) = TcpStream::connect(&upstream).await else {
return;
};
tokio::select! {
_ = copy_bidirectional(&mut downstream, &mut upstream_stream) => return,
_ = pause_rx => {}
}
drop(downstream);
drop(upstream_stream);
let mut blocked_connection_tx = Some(blocked_connection_tx);
loop {
tokio::select! {
_ = &mut resume_rx => break,
accepted = listener.accept() => {
let Ok((blocked, _)) = accepted else {
break;
};
drop(blocked);
if let Some(blocked_connection_tx) = blocked_connection_tx.take() {
let _ = blocked_connection_tx.send(());
}
}
}
}
loop {
let Ok((mut downstream, _)) = listener.accept().await else {
return;
};
let Ok(mut upstream_stream) = TcpStream::connect(&upstream).await else {
continue;
};
let _ = copy_bidirectional(&mut downstream, &mut upstream_stream).await;
}
}
async fn connect_websocket_when_ready(
websocket_url: &str,
) -> anyhow::Result<(
+121 -62
@@ -1,5 +1,6 @@
mod common;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context;
@@ -21,6 +22,7 @@ use tempfile::TempDir;
use test_case::test_case;
use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::sleep;
use tokio::time::timeout;
use common::DELAYED_OUTPUT_AFTER_EXIT_PARENT_ARG;
@@ -30,7 +32,7 @@ use common::exec_server::exec_server;
struct ProcessContext {
backend: Arc<dyn ExecBackend>,
server: Option<ExecServerHarness>,
_server: Option<ExecServerHarness>,
}
#[derive(Debug, PartialEq, Eq)]
@@ -55,13 +57,13 @@ async fn create_process_context(use_remote: bool) -> Result<ProcessContext> {
let environment = Environment::create_for_tests(Some(server.websocket_url().to_string()))?;
Ok(ProcessContext {
backend: environment.get_exec_backend(),
server: Some(server),
_server: Some(server),
})
} else {
let environment = Environment::create_for_tests(/*exec_server_url*/ None)?;
Ok(ProcessContext {
backend: environment.get_exec_backend(),
server: None,
_server: None,
})
}
}
@@ -634,88 +636,145 @@ async fn assert_exec_process_preserves_queued_events_before_subscribe(
#[cfg_attr(not(unix), ignore = "Unix-only exec-server process test")]
// Serialize tests that launch a real exec-server process through the full CLI.
#[serial_test::serial(remote_exec_server)]
async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
let mut context = create_process_context(/*use_remote*/ true).await?;
let session = context
.backend
async fn remote_exec_process_recovers_after_transport_disconnect() -> Result<()> {
let server = exec_server().await?;
let mut proxy = server.disconnectable_websocket_proxy().await?;
let environment = Environment::create_for_tests(Some(proxy.websocket_url().to_string()))?;
let backend = environment.get_exec_backend();
let temp_dir = TempDir::new()?;
let gate_path = temp_dir.path().join("release-output");
let emitted_path = temp_dir.path().join("output-emitted");
let session = backend
.start(ExecParams {
process_id: ProcessId::from("proc-disconnect"),
process_id: ProcessId::from("proc-recover"),
argv: vec![
"/bin/sh".to_string(),
"-c".to_string(),
"sleep 10".to_string(),
concat!(
"printf 'ready:%s\\n' \"$$\"; ",
"while [ ! -f \"$GATE\" ]; do /bin/sleep 0.01; done; ",
"printf 'during:%s\\n' \"$$\"; ",
": > \"$EMITTED\"; ",
"IFS= read -r line; ",
"printf 'after:%s:%s\\n' \"$$\" \"$line\"; ",
"exit 7",
)
.to_string(),
],
cwd: PathUri::from_path(std::env::current_dir()?)?,
env_policy: /*env_policy*/ None,
env: Default::default(),
env: HashMap::from([
(
"GATE".to_string(),
gate_path.to_string_lossy().into_owned(),
),
(
"EMITTED".to_string(),
emitted_path.to_string_lossy().into_owned(),
),
]),
tty: false,
pipe_stdin: false,
pipe_stdin: true,
arg0: None,
})
.await?;
let process = Arc::clone(&session.process);
let mut events = process.subscribe_events();
let process_for_pending_read = Arc::clone(&process);
let pending_read = tokio::spawn(async move {
process_for_pending_read
let mut output = Vec::new();
let mut last_seq = 0;
while !output.ends_with(b"\n") {
match timeout(Duration::from_secs(5), events.recv()).await?? {
ExecProcessEvent::Output(chunk) => {
assert_eq!(chunk.seq, last_seq + 1);
last_seq = chunk.seq;
output.extend_from_slice(&chunk.chunk.into_inner());
}
event => anyhow::bail!("expected ready output before disconnect, got {event:?}"),
}
}
let ready = String::from_utf8(output.clone())?;
let pid = ready
.strip_prefix("ready:")
.and_then(|line| line.strip_suffix('\n'))
.context("ready output should contain the process id")?
.to_string();
proxy.pause_and_disconnect().await?;
tokio::fs::write(&gate_path, b"").await?;
timeout(Duration::from_secs(5), async {
while tokio::fs::metadata(&emitted_path).await.is_err() {
sleep(Duration::from_millis(10)).await;
}
})
.await
.context("process did not emit output while disconnected")?;
let process_for_read = Arc::clone(&process);
let mut pending_read = tokio::spawn(async move {
process_for_read
.read(
/*after_seq*/ None,
/*after_seq*/ Some(last_seq),
/*max_bytes*/ None,
/*wait_ms*/ Some(60_000),
/*wait_ms*/ Some(0),
)
.await
});
let server = context
.server
.as_mut()
.expect("remote context should include exec-server harness");
server.shutdown().await?;
let event = timeout(Duration::from_secs(2), events.recv()).await??;
let ExecProcessEvent::Failed(event_message) = event else {
anyhow::bail!("expected process failure event, got {event:?}");
};
assert!(
event_message.starts_with("exec-server transport disconnected"),
"unexpected failure event: {event_message}"
timeout(Duration::from_millis(200), &mut pending_read)
.await
.is_err(),
"process reads should wait while recovery is in progress"
);
proxy.resume()?;
let recovered_read = timeout(Duration::from_secs(5), pending_read)
.await
.context("timed out waiting for a read after recovery")??;
let recovered_read = recovered_read?;
assert_eq!(recovered_read.failure, None);
let recovered_output = recovered_read
.chunks
.into_iter()
.flat_map(|chunk| chunk.chunk.into_inner())
.collect::<Vec<_>>();
assert_eq!(
String::from_utf8(recovered_output)?,
format!("during:{pid}\n")
);
let pending_response = timeout(Duration::from_secs(2), pending_read).await???;
let pending_message = pending_response
.failure
.expect("pending read should surface disconnect as a failure");
assert!(
pending_message.starts_with("exec-server transport disconnected"),
"unexpected pending failure message: {pending_message}"
);
let write = timeout(Duration::from_secs(5), process.write(b"hello\n".to_vec()))
.await
.context("timed out waiting for a write after recovery")??;
assert_eq!(write.status, WriteStatus::Accepted);
let mut wake_rx = process.subscribe_wake();
let response = read_process_until_change(process, &mut wake_rx, /*after_seq*/ None).await?;
let message = response
.failure
.expect("disconnect should surface as a failure");
assert!(
message.starts_with("exec-server transport disconnected"),
"unexpected failure message: {message}"
);
assert!(
response.closed,
"disconnect should close the process session"
);
let write_result = timeout(
Duration::from_secs(2),
session.process.write(b"hello".to_vec()),
)
.await
.context("timed out waiting for write after disconnect")?;
let write_error = write_result.expect_err("write after disconnect should fail");
assert!(
write_error
.to_string()
.starts_with("exec-server transport disconnected"),
"unexpected write error: {write_error}"
let mut saw_exit = false;
loop {
match timeout(Duration::from_secs(5), events.recv()).await?? {
ExecProcessEvent::Output(chunk) => {
assert_eq!(chunk.seq, last_seq + 1);
last_seq = chunk.seq;
output.extend_from_slice(&chunk.chunk.into_inner());
}
ExecProcessEvent::Exited { seq, exit_code } => {
assert_eq!(seq, last_seq + 1);
assert_eq!(exit_code, 7);
last_seq = seq;
saw_exit = true;
}
ExecProcessEvent::Closed { seq } => {
assert!(saw_exit, "closed must be delivered after exit");
assert_eq!(seq, last_seq + 1);
break;
}
ExecProcessEvent::Failed(message) => {
anyhow::bail!("process recovery failed: {message}");
}
}
}
assert_eq!(
String::from_utf8(output)?,
format!("ready:{pid}\nduring:{pid}\nafter:{pid}:hello\n")
);
Ok(())