mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
Recover exec process stdin writes (#28895)
## Summary

Remote stdio MCP servers send tool calls by writing JSON-RPC bytes through `process/write`. When the exec-server websocket drops at the wrong time, the remote process can survive session recovery, but the stdin write can still fail and surface to RMCP as a transport send error. RMCP then closes the stdio MCP transport, so tools like `node_repl` are lost even though the process/session recovery path is working.

This changes `process/write` to be safe to retry across exec-server recovery:

- adds a required `writeId` to `process/write`
- retries remote `Session::write` with the same `writeId` after reconnect
- remembers accepted write ids per process so duplicate retries return `Accepted` without writing the same bytes to child stdin again
- covers both the client retry path and server-side write id dedupe with tests

In simple terms:

```text
before:
  write to MCP stdin -> websocket closes -> write errors -> RMCP closes node_repl

after:
  write to MCP stdin -> websocket closes -> reconnect -> retry same writeId
  server either writes once or recognizes it already did
```
This commit is contained in:
@@ -155,6 +155,7 @@ pub(crate) struct SessionState {
|
||||
events: ExecProcessEventLog,
|
||||
ordered_events: StdMutex<OrderedSessionEvents>,
|
||||
recoverable: AtomicBool,
|
||||
next_write_id: AtomicU64,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -421,12 +422,14 @@ impl ExecServerClient {
|
||||
&self,
|
||||
process_id: &ProcessId,
|
||||
chunk: Vec<u8>,
|
||||
write_id: String,
|
||||
) -> Result<WriteResponse, ExecServerError> {
|
||||
self.call(
|
||||
EXEC_WRITE_METHOD,
|
||||
&WriteParams {
|
||||
process_id: process_id.clone(),
|
||||
chunk: chunk.into(),
|
||||
write_id,
|
||||
},
|
||||
)
|
||||
.await
|
||||
@@ -730,6 +733,7 @@ impl SessionState {
|
||||
),
|
||||
ordered_events: StdMutex::new(OrderedSessionEvents::default()),
|
||||
recoverable: AtomicBool::new(recoverable),
|
||||
next_write_id: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -829,6 +833,12 @@ impl SessionState {
|
||||
failure: Some(message),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the next write id for this session state as a decimal string.
///
/// The id is read from an atomic counter that is incremented on every call, so
/// successive calls on the same `SessionState` return different values.
// NOTE(review): ids look unique only within one `SessionState`; confirm that two
// states never write to the same process, otherwise a repeated id could be mistaken
// for a retry by the server-side dedupe.
fn next_write_id(&self) -> String {
|
||||
self.next_write_id
|
||||
// `fetch_add` returns the value held before the increment.
.fetch_add(1, Ordering::Relaxed)
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -885,7 +895,22 @@ impl Session {
|
||||
}
|
||||
|
||||
/// Writes `chunk` to the process for this session, retrying when the transport closes.
///
/// A single write id is taken from the session state before the first attempt and the
/// same id is sent with every retry made by this call.
///
/// Returns the first successful `WriteResponse`, or the first error that does not
/// qualify for a retry.
// NOTE(review): the loop shows no backoff or attempt cap; presumably `client.write`
// waits for reconnection before it can fail again — confirm against the client.
pub(crate) async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
|
||||
// NOTE(review): this two-argument call appears to be the pre-change body that the
// diff view shows next to its replacement; it passes no write id.
self.client.write(&self.process_id, chunk).await
|
||||
// Allocated once, outside the loop, so every attempt carries the same id.
let write_id = self.state.next_write_id();
|
||||
loop {
|
||||
match self
|
||||
.client
|
||||
// The chunk and id are cloned because a later attempt may need them again.
.write(&self.process_id, chunk.clone(), write_id.clone())
|
||||
.await
|
||||
{
|
||||
Ok(response) => return Ok(response),
|
||||
// Retry only for a closed-transport error while the client is not marked failed.
Err(error)
|
||||
if is_transport_closed_error(&error) && !self.client.inner.is_failed() =>
|
||||
{
|
||||
continue;
|
||||
}
|
||||
// Any other error is returned to the caller unchanged.
Err(error) => return Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn signal(&self, signal: ProcessSignal) -> Result<(), ExecServerError> {
|
||||
@@ -1110,6 +1135,8 @@ mod tests {
|
||||
use crate::protocol::EXEC_CLOSED_METHOD;
|
||||
use crate::protocol::EXEC_EXITED_METHOD;
|
||||
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
|
||||
use crate::protocol::EXEC_READ_METHOD;
|
||||
use crate::protocol::EXEC_WRITE_METHOD;
|
||||
use crate::protocol::ExecClosedNotification;
|
||||
use crate::protocol::ExecExitedNotification;
|
||||
use crate::protocol::ExecOutputDeltaNotification;
|
||||
@@ -1118,6 +1145,10 @@ mod tests {
|
||||
use crate::protocol::INITIALIZED_METHOD;
|
||||
use crate::protocol::InitializeResponse;
|
||||
use crate::protocol::ProcessOutputChunk;
|
||||
use crate::protocol::ReadResponse;
|
||||
use crate::protocol::WriteParams;
|
||||
use crate::protocol::WriteResponse;
|
||||
use crate::protocol::WriteStatus;
|
||||
|
||||
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
|
||||
where
|
||||
@@ -1685,6 +1716,121 @@ mod tests {
|
||||
server.await.expect("server task should finish");
|
||||
}
|
||||
|
||||
// Checks that `Session::write` resends the same write id after the first websocket
// connection is dropped and the session is resumed on a second connection.
#[tokio::test]
|
||||
async fn session_write_retries_same_write_id_after_recovery() {
|
||||
// Bind to port 0 so the OS picks a free port for the fake server.
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("listener should bind");
|
||||
let websocket_url = format!(
|
||||
"ws://{}",
|
||||
listener.local_addr().expect("listener should have address")
|
||||
);
|
||||
// The server task waits on this channel before it finishes.
let (finish_tx, finish_rx) = oneshot::channel();
|
||||
let server = tokio::spawn(async move {
|
||||
// First connection: a fresh session with no resume id expected.
let mut first = accept_websocket(&listener).await;
|
||||
complete_websocket_initialize(
|
||||
&mut first,
|
||||
"session-1",
|
||||
/*expected_resume_session_id*/ None,
|
||||
)
|
||||
.await;
|
||||
|
|
||||
// The first message after initialize must be the process/write request.
let first_write = read_jsonrpc_websocket(&mut first).await;
|
||||
let first_write = match first_write {
|
||||
JSONRPCMessage::Request(request) if request.method == EXEC_WRITE_METHOD => request,
|
||||
other => panic!("expected first process/write request, got {other:?}"),
|
||||
};
|
||||
let first_write_params: WriteParams =
|
||||
serde_json::from_value(first_write.params.expect("write params should exist"))
|
||||
.expect("write params should deserialize");
|
||||
assert_eq!(first_write_params.process_id.as_str(), "proc-write");
|
||||
assert_eq!(first_write_params.chunk.into_inner(), b"hello\n".to_vec());
|
||||
// Keep the id so the retried request can be compared against it.
let write_id = first_write_params.write_id;
|
||||
assert!(!write_id.is_empty());
|
||||
// Drop the connection without answering the write.
drop(first);
|
||||
|
|
||||
// Second connection: the client is expected to resume "session-1".
let mut resumed = accept_websocket(&listener).await;
|
||||
complete_websocket_initialize(
|
||||
&mut resumed,
|
||||
"session-1",
|
||||
/*expected_resume_session_id*/ Some("session-1"),
|
||||
)
|
||||
.await;
|
||||
|
|
||||
// A process/read request is expected before the retried write.
let recovery_read = read_jsonrpc_websocket(&mut resumed).await;
|
||||
let recovery_read = match recovery_read {
|
||||
JSONRPCMessage::Request(request) if request.method == EXEC_READ_METHOD => request,
|
||||
other => panic!("expected recovery process/read request, got {other:?}"),
|
||||
};
|
||||
// Answer the read with no chunks and a process that is still running.
write_jsonrpc_websocket(
|
||||
&mut resumed,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: recovery_read.id,
|
||||
result: serde_json::to_value(ReadResponse {
|
||||
chunks: Vec::new(),
|
||||
next_seq: 1,
|
||||
exited: false,
|
||||
exit_code: None,
|
||||
closed: false,
|
||||
failure: None,
|
||||
})
|
||||
.expect("read response should serialize"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
|
||||
// The retried write must match the first one, including its write id.
let retried_write = read_jsonrpc_websocket(&mut resumed).await;
|
||||
let retried_write = match retried_write {
|
||||
JSONRPCMessage::Request(request) if request.method == EXEC_WRITE_METHOD => request,
|
||||
other => panic!("expected retried process/write request, got {other:?}"),
|
||||
};
|
||||
let retried_write_params: WriteParams =
|
||||
serde_json::from_value(retried_write.params.expect("write params should exist"))
|
||||
.expect("write params should deserialize");
|
||||
assert_eq!(retried_write_params.process_id.as_str(), "proc-write");
|
||||
assert_eq!(retried_write_params.chunk.into_inner(), b"hello\n".to_vec());
|
||||
assert_eq!(retried_write_params.write_id, write_id);
|
||||
// This time the write is answered, with `Accepted`.
write_jsonrpc_websocket(
|
||||
&mut resumed,
|
||||
JSONRPCMessage::Response(JSONRPCResponse {
|
||||
id: retried_write.id,
|
||||
result: serde_json::to_value(WriteResponse {
|
||||
status: WriteStatus::Accepted,
|
||||
})
|
||||
.expect("write response should serialize"),
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
|
||||
// Hold the resumed connection until the client side signals completion.
finish_rx.await.expect("test should finish");
|
||||
});
|
||||
|
|
||||
// Client side: connect to the fake server and register the session under test.
let client = LazyRemoteExecServerClient::new(ExecServerTransportParams::WebSocketUrl {
|
||||
websocket_url,
|
||||
connect_timeout: Duration::from_secs(1),
|
||||
initialize_timeout: Duration::from_secs(1),
|
||||
});
|
||||
let stable_client = client.get().await.expect("client should connect");
|
||||
let session = stable_client
|
||||
.register_session(&ProcessId::from("proc-write"))
|
||||
.await
|
||||
.expect("session should register");
|
||||
|
|
||||
// The write must complete within two seconds despite the dropped connection.
let response = timeout(Duration::from_secs(2), session.write(b"hello\n".to_vec()))
|
||||
.await
|
||||
.expect("write should not time out")
|
||||
.expect("write should recover");
|
||||
assert_eq!(
|
||||
response,
|
||||
WriteResponse {
|
||||
status: WriteStatus::Accepted
|
||||
}
|
||||
);
|
||||
|
|
||||
// Release the server task and surface any panic from its assertions.
finish_tx.send(()).expect("test should finish");
|
||||
server.await.expect("server task should finish");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn explicit_resume_drains_notifications_before_initialize_response() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
@@ -54,6 +57,8 @@ use crate::rpc::invalid_request;
|
||||
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
|
||||
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
|
||||
const PROCESS_EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||
// Upper bound on how many accepted stdin write ids are remembered for one process;
// `AcceptedStdinWriteIds::remember` evicts the oldest ids beyond this count.
const RETAINED_STDIN_WRITE_IDS_PER_PROCESS: usize = 4096;
|
||||
// Counter, starting at 1, used to build the `local-{n}` write ids that the local
// write path passes to `exec_write`.
static NEXT_LOCAL_STDIN_WRITE_ID: AtomicU64 = AtomicU64::new(1);
|
||||
#[cfg(test)]
|
||||
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
|
||||
#[cfg(not(test))]
|
||||
@@ -70,6 +75,7 @@ struct RunningProcess {
|
||||
session: ExecCommandSession,
|
||||
tty: bool,
|
||||
pipe_stdin: bool,
|
||||
accepted_stdin_write_ids: Arc<Mutex<AcceptedStdinWriteIds>>,
|
||||
output: VecDeque<RetainedOutputChunk>,
|
||||
retained_bytes: usize,
|
||||
next_seq: u64,
|
||||
@@ -81,6 +87,37 @@ struct RunningProcess {
|
||||
closed: bool,
|
||||
}
|
||||
|
||||
/// Bounded cache of stdin write ids that have already been accepted for one process.
|
||||
///
|
||||
/// A remote client can retry `process/write` after reconnecting. Remembering accepted
|
||||
/// ids lets the server acknowledge the retried request without writing the same bytes
|
||||
/// to child stdin twice.
|
||||
///
/// At most `RETAINED_STDIN_WRITE_IDS_PER_PROCESS` ids are kept; older ids are forgotten
/// first, and a forgotten id is no longer recognized as a duplicate.
#[derive(Default)]
|
||||
struct AcceptedStdinWriteIds {
|
||||
// Membership set consulted by `contains`.
ids: HashSet<String>,
|
||||
// The same ids in insertion order, oldest at the front, used to pick evictions.
order: VecDeque<String>,
|
||||
}
|
||||
|
||||
impl AcceptedStdinWriteIds {
|
||||
/// Returns `true` if `write_id` is currently remembered as accepted.
fn contains(&self, write_id: &str) -> bool {
|
||||
self.ids.contains(write_id)
|
||||
}
|
||||
|
|
||||
/// Records `write_id` as accepted, evicting the oldest ids once the cache holds more
/// than `RETAINED_STDIN_WRITE_IDS_PER_PROCESS` entries.
///
/// Remembering an id that is already present changes nothing.
fn remember(&mut self, write_id: String) {
|
||||
// `insert` returns false when the id was already in the set.
if !self.ids.insert(write_id.clone()) {
|
||||
return;
|
||||
}
|
||||
|
|
||||
self.order.push_back(write_id);
|
||||
// Trim from the front (oldest first) until the cache is back within its bound.
while self.order.len() > RETAINED_STDIN_WRITE_IDS_PER_PROCESS {
|
||||
let Some(evicted) = self.order.pop_front() else {
|
||||
break;
|
||||
};
|
||||
// Keep the set in step with the queue.
self.ids.remove(&evicted);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ProcessStart;
|
||||
|
||||
enum ProcessEntry {
|
||||
@@ -247,6 +284,9 @@ impl LocalProcess {
|
||||
session: spawned.session,
|
||||
tty: params.tty,
|
||||
pipe_stdin: params.pipe_stdin,
|
||||
accepted_stdin_write_ids: Arc::new(
|
||||
Mutex::new(AcceptedStdinWriteIds::default()),
|
||||
),
|
||||
output: VecDeque::new(),
|
||||
retained_bytes: 0,
|
||||
next_seq: 1,
|
||||
@@ -383,7 +423,11 @@ impl LocalProcess {
|
||||
params: WriteParams,
|
||||
) -> Result<WriteResponse, JSONRPCErrorError> {
|
||||
let _input_bytes = params.chunk.0.len();
|
||||
let writer_tx = {
|
||||
if params.write_id.is_empty() {
|
||||
return Err(invalid_params("writeId must not be empty".to_string()));
|
||||
}
|
||||
|
||||
let (writer_tx, accepted_stdin_write_ids) = {
|
||||
let process_map = self.inner.processes.lock().await;
|
||||
let Some(process) = process_map.get(&params.process_id) else {
|
||||
return Ok(WriteResponse {
|
||||
@@ -400,13 +444,37 @@ impl LocalProcess {
|
||||
status: WriteStatus::StdinClosed,
|
||||
});
|
||||
}
|
||||
process.session.writer_sender()
|
||||
(
|
||||
process.session.writer_sender(),
|
||||
Arc::clone(&process.accepted_stdin_write_ids),
|
||||
)
|
||||
};
|
||||
|
||||
writer_tx
|
||||
.send(params.chunk.into_inner())
|
||||
if accepted_stdin_write_ids
|
||||
.lock()
|
||||
.await
|
||||
.contains(&params.write_id)
|
||||
{
|
||||
return Ok(WriteResponse {
|
||||
status: WriteStatus::Accepted,
|
||||
});
|
||||
}
|
||||
|
||||
let permit = writer_tx
|
||||
.reserve()
|
||||
.await
|
||||
.map_err(|_| internal_error("failed to write to process stdin".to_string()))?;
|
||||
let mut accepted_stdin_write_ids = accepted_stdin_write_ids.lock().await;
|
||||
if accepted_stdin_write_ids.contains(&params.write_id) {
|
||||
return Ok(WriteResponse {
|
||||
status: WriteStatus::Accepted,
|
||||
});
|
||||
}
|
||||
|
||||
// After this synchronous send, record the write id before any further await.
|
||||
// Otherwise a cancelled RPC handler could retry and write the same bytes again.
|
||||
permit.send(params.chunk.into_inner());
|
||||
accepted_stdin_write_ids.remember(params.write_id);
|
||||
|
||||
Ok(WriteResponse {
|
||||
status: WriteStatus::Accepted,
|
||||
@@ -601,6 +669,10 @@ impl LocalProcess {
|
||||
self.exec_write(WriteParams {
|
||||
process_id: process_id.clone(),
|
||||
chunk: chunk.into(),
|
||||
write_id: format!(
|
||||
"local-{}",
|
||||
NEXT_LOCAL_STDIN_WRITE_ID.fetch_add(1, Ordering::Relaxed)
|
||||
),
|
||||
})
|
||||
.await
|
||||
.map_err(map_handler_error)
|
||||
@@ -1023,6 +1095,7 @@ mod tests {
|
||||
session: dummy_session(),
|
||||
tty: false,
|
||||
pipe_stdin: false,
|
||||
accepted_stdin_write_ids: Arc::new(Mutex::new(AcceptedStdinWriteIds::default())),
|
||||
output: VecDeque::new(),
|
||||
retained_bytes: 0,
|
||||
next_seq: 1,
|
||||
|
||||
@@ -154,6 +154,7 @@ pub struct ReadResponse {
|
||||
// Parameters of a `process/write` request.
pub struct WriteParams {
|
||||
// Id of the process whose stdin should receive the bytes.
pub process_id: ProcessId,
|
||||
// Bytes to write to the process's stdin.
pub chunk: ByteChunk,
|
||||
// Caller-chosen id for this write. The server rejects an empty id and answers a
// repeated, still-remembered id with `Accepted` without writing the bytes again.
pub write_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
||||
@@ -144,7 +144,8 @@ async fn exec_server_defaults_omitted_pipe_stdin_to_closed_stdin() -> anyhow::Re
|
||||
"process/write",
|
||||
serde_json::json!({
|
||||
"processId": "proc-default-stdin",
|
||||
"chunk": "aWdub3JlZAo="
|
||||
"chunk": "aWdub3JlZAo=",
|
||||
"writeId": "write-default-stdin"
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
@@ -171,6 +172,140 @@ async fn exec_server_defaults_omitted_pipe_stdin_to_closed_stdin() -> anyhow::Re
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that sending the same `writeId` twice writes the bytes to the child's stdin
// only once, while both requests are answered with `Accepted`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_dedupes_retried_process_write_ids() -> anyhow::Result<()> {
|
||||
let mut server = exec_server().await?;
|
||||
// Handshake: send `initialize`, wait for its response, then send `initialized`.
let initialize_id = server
|
||||
.send_request(
|
||||
"initialize",
|
||||
serde_json::to_value(InitializeParams {
|
||||
client_name: "exec-server-test".to_string(),
|
||||
resume_session_id: None,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
let _ = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &initialize_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
|
||||
server
|
||||
.send_notification("initialized", serde_json::json!({}))
|
||||
.await?;
|
||||
|
|
||||
// Start a shell that reads two lines from stdin and echoes each with a `line:` prefix.
let process_start_id = server
|
||||
.send_request(
|
||||
"process/start",
|
||||
serde_json::json!({
|
||||
"processId": "proc-write-id",
|
||||
"argv": [
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
"IFS= read -r first; printf 'line:%s\\n' \"$first\"; IFS= read -r second; printf 'line:%s\\n' \"$second\""
|
||||
],
|
||||
"cwd": std::env::current_dir()?,
|
||||
"env": {},
|
||||
"tty": false,
|
||||
"pipeStdin": true,
|
||||
"arg0": null
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let _ = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &process_start_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
|
||||
// Chunks are base64: "Zmlyc3QK" is "first\n" and "c2Vjb25kCg==" is "second\n".
// "write-1" is sent twice to act as a retried request.
for (write_id, chunk) in [
|
||||
("write-1", "Zmlyc3QK"),
|
||||
("write-1", "Zmlyc3QK"),
|
||||
("write-2", "c2Vjb25kCg=="),
|
||||
] {
|
||||
let request_id = server
|
||||
.send_request(
|
||||
"process/write",
|
||||
serde_json::json!({
|
||||
"processId": "proc-write-id",
|
||||
"chunk": chunk,
|
||||
"writeId": write_id
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &request_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected process/write response");
|
||||
};
|
||||
let write_response: WriteResponse = serde_json::from_value(result)?;
|
||||
// Every request, including the duplicate, must report `Accepted`.
assert_eq!(
|
||||
write_response,
|
||||
WriteResponse {
|
||||
status: WriteStatus::Accepted
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
||||
// Collect process output over at most five reads.
let mut after_seq = None;
|
||||
let mut output = Vec::new();
|
||||
for _ in 0..5 {
|
||||
let read_id = server
|
||||
.send_request(
|
||||
"process/read",
|
||||
serde_json::json!({
|
||||
"processId": "proc-write-id",
|
||||
"afterSeq": after_seq,
|
||||
"maxBytes": null,
|
||||
"waitMs": 1000
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
let response = server
|
||||
.wait_for_event(|event| {
|
||||
matches!(
|
||||
event,
|
||||
JSONRPCMessage::Response(JSONRPCResponse { id, .. }) if id == &read_id
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let JSONRPCMessage::Response(JSONRPCResponse { result, .. }) = response else {
|
||||
panic!("expected process/read response");
|
||||
};
|
||||
let read_response: ReadResponse = serde_json::from_value(result)?;
|
||||
output.extend(
|
||||
read_response
|
||||
.chunks
|
||||
.into_iter()
|
||||
.flat_map(|chunk| chunk.chunk.into_inner()),
|
||||
);
|
||||
// Ask the next read to start after the last sequence number returned so far.
after_seq = Some(read_response.next_seq.saturating_sub(1));
|
||||
// Stop once the process output is closed or the second echoed line has arrived.
if read_response.closed || output.ends_with(b"line:second\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
||||
// A duplicated write would have fed "first" to the second `read` as well.
assert_eq!(
|
||||
String::from_utf8(output)?,
|
||||
"line:first\nline:second\n".to_string()
|
||||
);
|
||||
|
|
||||
server.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn exec_server_resumes_detached_session_without_killing_processes() -> anyhow::Result<()> {
|
||||
let mut server = exec_server().await?;
|
||||
|
||||
Reference in New Issue
Block a user