Recover exec process stdin writes (#28895)

## Summary

Remote stdio MCP servers send tool calls by writing JSON-RPC bytes
through `process/write`.

When the exec-server websocket drops at the wrong time, the remote
process can survive session recovery, but the stdin write can still fail
back to RMCP as a transport send error. RMCP then closes the stdio MCP
transport, so tools like `node_repl` are lost even though the
process/session recovery path is working.

This changes `process/write` to be safe to retry across exec-server
recovery:

- adds a required `writeId` to `process/write`
- retries remote `Session::write` with the same `writeId` after
reconnect
- remembers accepted write ids per process so duplicate retries return
`Accepted` without writing the same bytes to child stdin again
- covers both the client retry path and server-side write id dedupe with
tests

In simple terms:

```text
before:
write to MCP stdin -> websocket closes -> write errors -> RMCP closes node_repl

after:
write to MCP stdin -> websocket closes -> reconnect -> retry same writeId
server either writes once or recognizes it already did
```
This commit is contained in:
jif
2026-06-18 18:04:26 +01:00
committed by GitHub
Unverified
parent 07298a948c
commit 83e6a786a2
4 changed files with 361 additions and 6 deletions
+147 -1
View File
@@ -155,6 +155,7 @@ pub(crate) struct SessionState {
events: ExecProcessEventLog,
ordered_events: StdMutex<OrderedSessionEvents>,
recoverable: AtomicBool,
next_write_id: AtomicU64,
}
#[derive(Default)]
@@ -421,12 +422,14 @@ impl ExecServerClient {
&self,
process_id: &ProcessId,
chunk: Vec<u8>,
write_id: String,
) -> Result<WriteResponse, ExecServerError> {
self.call(
EXEC_WRITE_METHOD,
&WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
write_id,
},
)
.await
@@ -730,6 +733,7 @@ impl SessionState {
),
ordered_events: StdMutex::new(OrderedSessionEvents::default()),
recoverable: AtomicBool::new(recoverable),
next_write_id: AtomicU64::new(1),
}
}
@@ -829,6 +833,12 @@ impl SessionState {
failure: Some(message),
}
}
fn next_write_id(&self) -> String {
self.next_write_id
.fetch_add(1, Ordering::Relaxed)
.to_string()
}
}
impl Session {
@@ -885,7 +895,22 @@ impl Session {
}
pub(crate) async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
self.client.write(&self.process_id, chunk).await
let write_id = self.state.next_write_id();
loop {
match self
.client
.write(&self.process_id, chunk.clone(), write_id.clone())
.await
{
Ok(response) => return Ok(response),
Err(error)
if is_transport_closed_error(&error) && !self.client.inner.is_failed() =>
{
continue;
}
Err(error) => return Err(error),
}
}
}
pub(crate) async fn signal(&self, signal: ProcessSignal) -> Result<(), ExecServerError> {
@@ -1110,6 +1135,8 @@ mod tests {
use crate::protocol::EXEC_CLOSED_METHOD;
use crate::protocol::EXEC_EXITED_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::EXEC_READ_METHOD;
use crate::protocol::EXEC_WRITE_METHOD;
use crate::protocol::ExecClosedNotification;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
@@ -1118,6 +1145,10 @@ mod tests {
use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeResponse;
use crate::protocol::ProcessOutputChunk;
use crate::protocol::ReadResponse;
use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
use crate::protocol::WriteStatus;
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
where
@@ -1685,6 +1716,121 @@ mod tests {
server.await.expect("server task should finish");
}
#[tokio::test]
async fn session_write_retries_same_write_id_after_recovery() {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let websocket_url = format!(
"ws://{}",
listener.local_addr().expect("listener should have address")
);
let (finish_tx, finish_rx) = oneshot::channel();
let server = tokio::spawn(async move {
let mut first = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut first,
"session-1",
/*expected_resume_session_id*/ None,
)
.await;
let first_write = read_jsonrpc_websocket(&mut first).await;
let first_write = match first_write {
JSONRPCMessage::Request(request) if request.method == EXEC_WRITE_METHOD => request,
other => panic!("expected first process/write request, got {other:?}"),
};
let first_write_params: WriteParams =
serde_json::from_value(first_write.params.expect("write params should exist"))
.expect("write params should deserialize");
assert_eq!(first_write_params.process_id.as_str(), "proc-write");
assert_eq!(first_write_params.chunk.into_inner(), b"hello\n".to_vec());
let write_id = first_write_params.write_id;
assert!(!write_id.is_empty());
drop(first);
let mut resumed = accept_websocket(&listener).await;
complete_websocket_initialize(
&mut resumed,
"session-1",
/*expected_resume_session_id*/ Some("session-1"),
)
.await;
let recovery_read = read_jsonrpc_websocket(&mut resumed).await;
let recovery_read = match recovery_read {
JSONRPCMessage::Request(request) if request.method == EXEC_READ_METHOD => request,
other => panic!("expected recovery process/read request, got {other:?}"),
};
write_jsonrpc_websocket(
&mut resumed,
JSONRPCMessage::Response(JSONRPCResponse {
id: recovery_read.id,
result: serde_json::to_value(ReadResponse {
chunks: Vec::new(),
next_seq: 1,
exited: false,
exit_code: None,
closed: false,
failure: None,
})
.expect("read response should serialize"),
}),
)
.await;
let retried_write = read_jsonrpc_websocket(&mut resumed).await;
let retried_write = match retried_write {
JSONRPCMessage::Request(request) if request.method == EXEC_WRITE_METHOD => request,
other => panic!("expected retried process/write request, got {other:?}"),
};
let retried_write_params: WriteParams =
serde_json::from_value(retried_write.params.expect("write params should exist"))
.expect("write params should deserialize");
assert_eq!(retried_write_params.process_id.as_str(), "proc-write");
assert_eq!(retried_write_params.chunk.into_inner(), b"hello\n".to_vec());
assert_eq!(retried_write_params.write_id, write_id);
write_jsonrpc_websocket(
&mut resumed,
JSONRPCMessage::Response(JSONRPCResponse {
id: retried_write.id,
result: serde_json::to_value(WriteResponse {
status: WriteStatus::Accepted,
})
.expect("write response should serialize"),
}),
)
.await;
finish_rx.await.expect("test should finish");
});
let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl {
websocket_url,
connect_timeout: Duration::from_secs(1),
initialize_timeout: Duration::from_secs(1),
});
let stable_client = client.get().await.expect("client should connect");
let session = stable_client
.register_session(&ProcessId::from("proc-write"))
.await
.expect("session should register");
let response = timeout(Duration::from_secs(2), session.write(b"hello\n".to_vec()))
.await
.expect("write should not time out")
.expect("write should recover");
assert_eq!(
response,
WriteResponse {
status: WriteStatus::Accepted
}
);
finish_tx.send(()).expect("test should finish");
server.await.expect("server task should finish");
}
#[tokio::test]
async fn explicit_resume_drains_notifications_before_initialize_response() {
let listener = TcpListener::bind("127.0.0.1:0")
+77 -4
View File
@@ -1,7 +1,10 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use codex_app_server_protocol::JSONRPCErrorError;
@@ -54,6 +57,8 @@ use crate::rpc::invalid_request;
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
const PROCESS_EVENT_CHANNEL_CAPACITY: usize = 256;
const RETAINED_STDIN_WRITE_IDS_PER_PROCESS: usize = 4096;
static NEXT_LOCAL_STDIN_WRITE_ID: AtomicU64 = AtomicU64::new(1);
#[cfg(test)]
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
#[cfg(not(test))]
@@ -70,6 +75,7 @@ struct RunningProcess {
session: ExecCommandSession,
tty: bool,
pipe_stdin: bool,
accepted_stdin_write_ids: Arc<Mutex<AcceptedStdinWriteIds>>,
output: VecDeque<RetainedOutputChunk>,
retained_bytes: usize,
next_seq: u64,
@@ -81,6 +87,37 @@ struct RunningProcess {
closed: bool,
}
/// Bounded cache of stdin write ids that have already been accepted for one process.
///
/// A remote client can retry `process/write` after reconnecting. Remembering accepted
/// ids lets the server acknowledge the retried request without writing the same bytes
/// to child stdin twice.
#[derive(Default)]
struct AcceptedStdinWriteIds {
ids: HashSet<String>,
order: VecDeque<String>,
}
impl AcceptedStdinWriteIds {
fn contains(&self, write_id: &str) -> bool {
self.ids.contains(write_id)
}
fn remember(&mut self, write_id: String) {
if !self.ids.insert(write_id.clone()) {
return;
}
self.order.push_back(write_id);
while self.order.len() > RETAINED_STDIN_WRITE_IDS_PER_PROCESS {
let Some(evicted) = self.order.pop_front() else {
break;
};
self.ids.remove(&evicted);
}
}
}
struct ProcessStart;
enum ProcessEntry {
@@ -247,6 +284,9 @@ impl LocalProcess {
session: spawned.session,
tty: params.tty,
pipe_stdin: params.pipe_stdin,
accepted_stdin_write_ids: Arc::new(
Mutex::new(AcceptedStdinWriteIds::default()),
),
output: VecDeque::new(),
retained_bytes: 0,
next_seq: 1,
@@ -383,7 +423,11 @@ impl LocalProcess {
params: WriteParams,
) -> Result<WriteResponse, JSONRPCErrorError> {
let _input_bytes = params.chunk.0.len();
let writer_tx = {
if params.write_id.is_empty() {
return Err(invalid_params("writeId must not be empty".to_string()));
}
let (writer_tx, accepted_stdin_write_ids) = {
let process_map = self.inner.processes.lock().await;
let Some(process) = process_map.get(&params.process_id) else {
return Ok(WriteResponse {
@@ -400,13 +444,37 @@ impl LocalProcess {
status: WriteStatus::StdinClosed,
});
}
process.session.writer_sender()
(
process.session.writer_sender(),
Arc::clone(&process.accepted_stdin_write_ids),
)
};
writer_tx
.send(params.chunk.into_inner())
if accepted_stdin_write_ids
.lock()
.await
.contains(&params.write_id)
{
return Ok(WriteResponse {
status: WriteStatus::Accepted,
});
}
let permit = writer_tx
.reserve()
.await
.map_err(|_| internal_error("failed to write to process stdin".to_string()))?;
let mut accepted_stdin_write_ids = accepted_stdin_write_ids.lock().await;
if accepted_stdin_write_ids.contains(&params.write_id) {
return Ok(WriteResponse {
status: WriteStatus::Accepted,
});
}
// After this synchronous send, record the write id before any further await.
// Otherwise a cancelled RPC handler could retry and write the same bytes again.
permit.send(params.chunk.into_inner());
accepted_stdin_write_ids.remember(params.write_id);
Ok(WriteResponse {
status: WriteStatus::Accepted,
@@ -601,6 +669,10 @@ impl LocalProcess {
self.exec_write(WriteParams {
process_id: process_id.clone(),
chunk: chunk.into(),
write_id: format!(
"local-{}",
NEXT_LOCAL_STDIN_WRITE_ID.fetch_add(1, Ordering::Relaxed)
),
})
.await
.map_err(map_handler_error)
@@ -1023,6 +1095,7 @@ mod tests {
session: dummy_session(),
tty: false,
pipe_stdin: false,
accepted_stdin_write_ids: Arc::new(Mutex::new(AcceptedStdinWriteIds::default())),
output: VecDeque::new(),
retained_bytes: 0,
next_seq: 1,
+1
View File
@@ -154,6 +154,7 @@ pub struct ReadResponse {
pub struct WriteParams {
pub process_id: ProcessId,
pub chunk: ByteChunk,
pub write_id: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+136 -1
View File
@@ -144,7 +144,8 @@ async fn exec_server_defaults_omitted_pipe_stdin_to_closed_stdin() -> anyhow::Re
"process/write",
serde_json::json!({
"processId": "proc-default-stdin",
"chunk": "aWdub3JlZAo="
"chunk": "aWdub3JlZAo=",
"writeId": "write-default-stdin"
}),
)
.await?;
@@ -171,6 +172,140 @@ async fn exec_server_defaults_omitted_pipe_stdin_to_closed_stdin() -> anyhow::Re
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_dedupes_retried_process_write_ids() -> anyhow::Result<()> {
let mut server = exec_server().await?;
let initialize_id = server
.send_request(
"initialize",
serde_json::to_value(InitializeParams {
client_name: "exec-server-test".to_string(),
resume_session_id: None,
})?,
)
.await?;
let _ = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id
)
})
.await?;
server
.send_notification("initialized", serde_json::json!({}))
.await?;
let process_start_id = server
.send_request(
"process/start",
serde_json::json!({
"processId": "proc-write-id",
"argv": [
"/bin/sh",
"-c",
"IFS= read -r first; printf 'line:%s\\n' \"$first\"; IFS= read -r second; printf 'line:%s\\n' \"$second\""
],
"cwd": std::env::current_dir()?,
"env": {},
"tty": false,
"pipeStdin": true,
"arg0": null
}),
)
.await?;
let _ = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id
)
})
.await?;
for (write_id, chunk) in [
("write-1", "Zmlyc3QK"),
("write-1", "Zmlyc3QK"),
("write-2", "c2Vjb25kCg=="),
] {
let request_id = server
.send_request(
"process/write",
serde_json::json!({
"processId": "proc-write-id",
"chunk": chunk,
"writeId": write_id
}),
)
.await?;
let response = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &request_id
)
})
.await?;
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
panic!("expected process/write response");
};
let write_response: WriteResponse = serde_json::from_value(result)?;
assert_eq!(
write_response,
WriteResponse {
status: WriteStatus::Accepted
}
);
}
let mut after_seq = None;
let mut output = Vec::new();
for _ in 0..5 {
let read_id = server
.send_request(
"process/read",
serde_json::json!({
"processId": "proc-write-id",
"afterSeq": after_seq,
"maxBytes": null,
"waitMs": 1000
}),
)
.await?;
let response = server
.wait_for_event(|event| {
matches!(
event,
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &read_id
)
})
.await?;
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
panic!("expected process/read response");
};
let read_response: ReadResponse = serde_json::from_value(result)?;
output.extend(
read_response
.chunks
.into_iter()
.flat_map(|chunk| chunk.chunk.into_inner()),
);
after_seq = Some(read_response.next_seq.saturating_sub(1));
if read_response.closed || output.ends_with(b"line:second\n") {
break;
}
}
assert_eq!(
String::from_utf8(output)?,
"line:first\nline:second\n".to_string()
);
server.shutdown().await?;
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_server_resumes_detached_session_without_killing_processes() -> anyhow::Result<()> {
let mut server = exec_server().await?;