feat: exec-server prep for unified exec (#15691)

This PR partially rebases `unified_exec` on the `exec-server` and adapts
the `exec-server` accordingly.

## What changed in `exec-server`

1. Replaced the old "broadcast-driven; process-global" event model with
process-scoped session events. The goal is to be able to have a dedicated
handler for each process.
2. Extended the protocol contract to support explicit lifecycle status and
stream ordering:
- `WriteResponse` now returns `WriteStatus` (Accepted, UnknownProcess,
StdinClosed, Starting) instead of a bool.
  - Added seq fields to output/exited notifications.
  - Added terminal process/closed notification.
3. Demultiplexed remote notifications into per-process channels, the same
approach as for the event system.
4. Local and remote backends now both implement ExecBackend.
5. Local backend wraps internal process ID/operations into per-process
ExecProcess objects.
6. Remote backend registers a session channel before launch and
unregisters on failed launch.

## What changed in `unified_exec`

1. Added unified process-state model and backend-neutral process
wrapper. This will probably disappear in the future, but it makes it
easier to keep the work flowing on both sides.
- `UnifiedExecProcess` now handles both local PTY sessions and remote
exec-server processes through a shared `ProcessHandle`.
- Added `ProcessState` to track has_exited, exit_code, and terminal
failure message consistently across backends.
2. Routed write and lifecycle handling through process-level methods.

## Some rationale

1. The change centralizes execution transport in exec-server while
preserving policy and orchestration ownership in core, avoiding
duplicated launch approval logic. This comes from internal discussion.
2. Session-scoped events remove coupling/cross-talk between processes
and make stream ordering and terminal state explicit (seq, closed,
failed).
3. The failure-path surfacing (remote launch failures, write failures,
transport disconnects) makes command tool output and cleanup behavior
deterministic.

## Follow-ups:
* Unify the concept of thread ID behind an obfuscated struct
* FD handling
* Full zsh-fork compatibility
* Full network sandboxing compatibility
* Handle ws disconnection
This commit is contained in:
jif-oai
2026-03-26 14:22:34 +00:00
committed by GitHub
Unverified
parent 4a5635b5a0
commit 7dac332c93
24 changed files with 1933 additions and 325 deletions
+1 -1
View File
@@ -614,7 +614,7 @@
"anyhow_1.0.101": "{\"dependencies\":[{\"name\":\"backtrace\",\"optional\":true,\"req\":\"^0.3.51\"},{\"default_features\":false,\"kind\":\"dev\",\"name\":\"futures\",\"req\":\"^0.3\"},{\"kind\":\"dev\",\"name\":\"rustversion\",\"req\":\"^1.0.6\"},{\"features\":[\"full\"],\"kind\":\"dev\",\"name\":\"syn\",\"req\":\"^2.0\"},{\"kind\":\"dev\",\"name\":\"thiserror\",\"req\":\"^2\"},{\"features\":[\"diff\"],\"kind\":\"dev\",\"name\":\"trybuild\",\"req\":\"^1.0.108\"}],\"features\":{\"default\":[\"std\"],\"std\":[]}}",
"arbitrary_1.4.2": "{\"dependencies\":[{\"name\":\"derive_arbitrary\",\"optional\":true,\"req\":\"~1.4.0\"},{\"kind\":\"dev\",\"name\":\"exhaustigen\",\"req\":\"^0.1.0\"}],\"features\":{\"derive\":[\"derive_arbitrary\"]}}",
"arboard_3.6.1": "{\"dependencies\":[{\"features\":[\"std\"],\"name\":\"clipboard-win\",\"req\":\"^5.3.1\",\"target\":\"cfg(windows)\"},{\"kind\":\"dev\",\"name\":\"env_logger\",\"req\":\"^0.10.2\"},{\"default_features\":false,\"features\":[\"png\"],\"name\":\"image\",\"optional\":true,\"req\":\"^0.25\",\"target\":\"cfg(all(unix, not(any(target_os=\\\"macos\\\", target_os=\\\"android\\\", target_os=\\\"emscripten\\\"))))\"},{\"default_features\":false,\"features\":[\"tiff\"],\"name\":\"image\",\"optional\":true,\"req\":\"^0.25\",\"target\":\"cfg(target_os = \\\"macos\\\")\"},{\"default_features\":false,\"features\":[\"png\",\"bmp\"],\"name\":\"image\",\"optional\":true,\"req\":\"^0.25\",\"target\":\"cfg(windows)\"},{\"name\":\"log\",\"req\":\"^0.4\",\"target\":\"cfg(all(unix, not(any(target_os=\\\"macos\\\", target_os=\\\"android\\\", target_os=\\\"emscripten\\\"))))\"},{\"name\":\"log\",\"req\":\"^0.4\",\"target\":\"cfg(windows)\"},{\"name\":\"objc2\",\"req\":\"^0.6.0\",\"target\":\"cfg(target_os = \\\"macos\\\")\"},{\"default_features\":false,\"features\":[\"std\",\"objc2-core-graphics\",\"NSPasteboard\",\"NSPasteboardItem\",\"NSImage\"],\"name\":\"objc2-app-kit\",\"req\":\"^0.3.0\",\"target\":\"cfg(target_os = \\\"macos\\\")\"},{\"default_features\":false,\"features\":[\"std\",\"CFCGTypes\"],\"name\":\"objc2-core-foundation\",\"optional\":true,\"req\":\"^0.3.0\",\"target\":\"cfg(target_os = \\\"macos\\\")\"},{\"default_features\":false,\"features\":[\"std\",\"CGImage\",\"CGColorSpace\",\"CGDataProvider\"],\"name\":\"objc2-core-graphics\",\"optional\":true,\"req\":\"^0.3.0\",\"target\":\"cfg(target_os = \\\"macos\\\")\"},{\"default_features\":false,\"features\":[\"std\",\"NSArray\",\"NSString\",\"NSEnumerator\",\"NSGeometry\",\"NSValue\"],\"name\":\"objc2-foundation\",\"req\":\"^0.3.0\",\"target\":\"cfg(target_os = \\\"macos\\\")\"},{\"name\":\"parking_lot\",\"req\":\"^0.12\",\"target\":\"cfg(all(unix, not(any(target_os=\\\"macos\\\", target_os=\\\"android\\\", 
target_os=\\\"emscripten\\\"))))\"},{\"name\":\"percent-encoding\",\"req\":\"^2.3.1\",\"target\":\"cfg(all(unix, not(any(target_os=\\\"macos\\\", target_os=\\\"android\\\", target_os=\\\"emscripten\\\"))))\"},{\"features\":[\"Win32_Foundation\",\"Win32_Storage_FileSystem\",\"Win32_System_DataExchange\",\"Win32_System_Memory\",\"Win32_System_Ole\",\"Win32_UI_Shell\"],\"name\":\"windows-sys\",\"req\":\">=0.52.0, <0.61.0\",\"target\":\"cfg(windows)\"},{\"name\":\"wl-clipboard-rs\",\"optional\":true,\"req\":\"^0.9.0\",\"target\":\"cfg(all(unix, not(any(target_os=\\\"macos\\\", target_os=\\\"android\\\", target_os=\\\"emscripten\\\"))))\"},{\"name\":\"x11rb\",\"req\":\"^0.13\",\"target\":\"cfg(all(unix, not(any(target_os=\\\"macos\\\", target_os=\\\"android\\\", target_os=\\\"emscripten\\\"))))\"}],\"features\":{\"core-graphics\":[\"dep:objc2-core-graphics\"],\"default\":[\"image-data\"],\"image\":[\"dep:image\"],\"image-data\":[\"dep:objc2-core-graphics\",\"dep:objc2-core-foundation\",\"image\",\"windows-sys\",\"core-graphics\"],\"wayland-data-control\":[\"wl-clipboard-rs\"],\"windows-sys\":[\"windows-sys/Win32_Graphics_Gdi\"],\"wl-clipboard-rs\":[\"dep:wl-clipboard-rs\"]}}",
"arc-swap_1.8.2": "{\"dependencies\":[{\"kind\":\"dev\",\"name\":\"adaptive-barrier\",\"req\":\"~1\"},{\"kind\":\"dev\",\"name\":\"criterion\",\"req\":\"~0.7\"},{\"kind\":\"dev\",\"name\":\"crossbeam-utils\",\"req\":\"~0.8\"},{\"kind\":\"dev\",\"name\":\"itertools\",\"req\":\"^0.14\"},{\"kind\":\"dev\",\"name\":\"num_cpus\",\"req\":\"~1\"},{\"kind\":\"dev\",\"name\":\"once_cell\",\"req\":\"~1\"},{\"kind\":\"dev\",\"name\":\"parking_lot\",\"req\":\"~0.12\"},{\"kind\":\"dev\",\"name\":\"proptest\",\"req\":\"^1\"},{\"name\":\"rustversion\",\"req\":\"^1\"},{\"features\":[\"rc\"],\"name\":\"serde\",\"optional\":true,\"req\":\"^1\"},{\"kind\":\"dev\",\"name\":\"serde_derive\",\"req\":\"^1.0.130\"},{\"kind\":\"dev\",\"name\":\"serde_test\",\"req\":\"^1.0.177\"}],\"features\":{\"experimental-strategies\":[],\"experimental-thread-local\":[],\"internal-test-strategies\":[],\"weak\":[]}}",
"arc-swap_1.9.0": "{\"dependencies\":[{\"kind\":\"dev\",\"name\":\"adaptive-barrier\",\"req\":\"~1\"},{\"kind\":\"dev\",\"name\":\"criterion\",\"req\":\"~0.7\"},{\"kind\":\"dev\",\"name\":\"crossbeam-utils\",\"req\":\"~0.8\"},{\"kind\":\"dev\",\"name\":\"itertools\",\"req\":\"^0.14\"},{\"kind\":\"dev\",\"name\":\"num_cpus\",\"req\":\"~1\"},{\"kind\":\"dev\",\"name\":\"once_cell\",\"req\":\"~1\"},{\"kind\":\"dev\",\"name\":\"parking_lot\",\"req\":\"~0.12\"},{\"kind\":\"dev\",\"name\":\"proptest\",\"req\":\"^1\"},{\"name\":\"rustversion\",\"req\":\"^1\"},{\"features\":[\"rc\"],\"name\":\"serde\",\"optional\":true,\"req\":\"^1\"},{\"kind\":\"dev\",\"name\":\"serde_derive\",\"req\":\"^1.0.130\"},{\"kind\":\"dev\",\"name\":\"serde_test\",\"req\":\"^1.0.177\"}],\"features\":{\"experimental-strategies\":[],\"experimental-thread-local\":[],\"internal-test-strategies\":[],\"weak\":[]}}",
"arrayvec_0.7.6": "{\"dependencies\":[{\"kind\":\"dev\",\"name\":\"bencher\",\"req\":\"^0.1.4\"},{\"default_features\":false,\"name\":\"borsh\",\"optional\":true,\"req\":\"^1.2.0\"},{\"kind\":\"dev\",\"name\":\"matches\",\"req\":\"^0.1\"},{\"default_features\":false,\"name\":\"serde\",\"optional\":true,\"req\":\"^1.0\"},{\"kind\":\"dev\",\"name\":\"serde_test\",\"req\":\"^1.0\"},{\"default_features\":false,\"name\":\"zeroize\",\"optional\":true,\"req\":\"^1.4\"}],\"features\":{\"default\":[\"std\"],\"std\":[]}}",
"ascii-canvas_3.0.0": "{\"dependencies\":[{\"kind\":\"dev\",\"name\":\"diff\",\"req\":\"^0.1\"},{\"name\":\"term\",\"req\":\"^0.7\"}],\"features\":{}}",
"ascii_1.1.0": "{\"dependencies\":[{\"name\":\"serde\",\"optional\":true,\"req\":\"^1.0.25\"},{\"name\":\"serde_test\",\"optional\":true,\"req\":\"^1.0\"}],\"features\":{\"alloc\":[],\"default\":[\"std\"],\"std\":[\"alloc\"]}}",
+3 -2
View File
@@ -453,9 +453,9 @@ dependencies = [
[[package]]
name = "arc-swap"
version = "1.8.2"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5"
checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6"
dependencies = [
"rustversion",
]
@@ -2031,6 +2031,7 @@ name = "codex-exec-server"
version = "0.0.0"
dependencies = [
"anyhow",
"arc-swap",
"async-trait",
"base64 0.22.1",
"clap",
+1
View File
@@ -189,6 +189,7 @@ allocative = "0.3.3"
ansi-to-tui = "7.0.0"
anyhow = "1"
arboard = { version = "3", features = ["wayland-data-control"] }
arc-swap = "1.9.0"
assert_cmd = "2"
assert_matches = "1.5.0"
async-channel = "2.3.1"
+1 -1
View File
@@ -18,7 +18,7 @@ workspace = true
[dependencies]
anyhow = { workspace = true }
arc-swap = "1.8.2"
arc-swap = { workspace = true }
async-channel = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
@@ -45,9 +45,12 @@ use futures::future::BoxFuture;
use std::collections::HashMap;
use std::path::PathBuf;
/// Request payload used by the unified-exec runtime after approvals and
/// sandbox preferences have been resolved for the current turn.
#[derive(Clone, Debug)]
pub struct UnifiedExecRequest {
pub command: Vec<String>,
pub process_id: i32,
pub cwd: PathBuf,
pub env: HashMap<String, String>,
pub explicit_env_overrides: HashMap<String, String>,
@@ -61,6 +64,8 @@ pub struct UnifiedExecRequest {
pub exec_approval_requirement: ExecApprovalRequirement,
}
/// Cache key for approval decisions that can be reused across equivalent
/// unified-exec launches.
#[derive(serde::Serialize, Clone, Debug, Eq, PartialEq, Hash)]
pub struct UnifiedExecApprovalKey {
pub command: Vec<String>,
@@ -70,12 +75,15 @@ pub struct UnifiedExecApprovalKey {
pub additional_permissions: Option<PermissionProfile>,
}
/// Runtime adapter that keeps policy and sandbox orchestration on the
/// unified-exec side while delegating process startup to the manager.
pub struct UnifiedExecRuntime<'a> {
manager: &'a UnifiedExecProcessManager,
shell_mode: UnifiedExecShellMode,
}
impl<'a> UnifiedExecRuntime<'a> {
/// Creates a runtime bound to the shared unified-exec process manager.
pub fn new(manager: &'a UnifiedExecProcessManager, shell_mode: UnifiedExecShellMode) -> Self {
Self {
manager,
@@ -232,12 +240,19 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
.await?
{
Some(prepared) => {
if ctx.turn.environment.exec_server_url().is_some() {
return Err(ToolError::Rejected(
"unified_exec zsh-fork is not supported when exec_server_url is configured".to_string(),
));
}
return self
.manager
.open_session_with_exec_env(
req.process_id,
&prepared.exec_request,
req.tty,
prepared.spawn_lifecycle,
ctx.turn.environment.as_ref(),
)
.await
.map_err(|err| match err {
@@ -268,7 +283,13 @@ impl<'a> ToolRuntime<UnifiedExecRequest, UnifiedExecProcess> for UnifiedExecRunt
.env_for(command, options, req.network.as_ref())
.map_err(|err| ToolError::Codex(err.into()))?;
self.manager
.open_session_with_exec_env(&exec_env, req.tty, Box::new(NoopSpawnLifecycle))
.open_session_with_exec_env(
req.process_id,
&exec_env,
req.tty,
Box::new(NoopSpawnLifecycle),
ctx.turn.environment.as_ref(),
)
.await
.map_err(|err| match err {
UnifiedExecError::SandboxDenied { output, .. } => {
+76 -14
View File
@@ -20,6 +20,7 @@ use crate::protocol::ExecCommandSource;
use crate::protocol::ExecOutputStream;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::events::ToolEventFailure;
use crate::tools::events::ToolEventStage;
use crate::unified_exec::head_tail_buffer::HeadTailBuffer;
@@ -121,21 +122,36 @@ pub(crate) fn spawn_exit_watcher(
exit_token.cancelled().await;
output_drained.notified().await;
let exit_code = process.exit_code().unwrap_or(-1);
let duration = Instant::now().saturating_duration_since(started_at);
emit_exec_end_for_unified_exec(
session_ref,
turn_ref,
call_id,
command,
cwd,
Some(process_id.to_string()),
transcript,
String::new(),
exit_code,
duration,
)
.await;
if let Some(message) = process.failure_message() {
emit_failed_exec_end_for_unified_exec(
session_ref,
turn_ref,
call_id,
command,
cwd,
Some(process_id.to_string()),
transcript,
message,
duration,
)
.await;
} else {
let exit_code = process.exit_code().unwrap_or(-1);
emit_exec_end_for_unified_exec(
session_ref,
turn_ref,
call_id,
command,
cwd,
Some(process_id.to_string()),
transcript,
String::new(),
exit_code,
duration,
)
.await;
}
});
}
@@ -213,6 +229,52 @@ pub(crate) async fn emit_exec_end_for_unified_exec(
.await;
}
/// Emits the failure-stage tool event for a unified-exec process that ended
/// with a failure message rather than a regular exit code.
///
/// The text resolved from `transcript` is reported as stdout and `message`
/// as stderr. The aggregated output is the resolved text followed by
/// `message` on a new line, or just `message` when the resolved text is
/// empty. The reported exit code is always `-1` and `timed_out` is `false`.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn emit_failed_exec_end_for_unified_exec(
    session_ref: Arc<Session>,
    turn_ref: Arc<TurnContext>,
    call_id: String,
    command: Vec<String>,
    cwd: PathBuf,
    process_id: Option<String>,
    transcript: Arc<Mutex<HeadTailBuffer>>,
    message: String,
    duration: Duration,
) {
    let captured = resolve_aggregated_output(&transcript, String::new()).await;

    // Keep the failure message visible even when nothing was captured.
    let combined = if !captured.is_empty() {
        format!("{captured}\n{message}")
    } else {
        message.clone()
    };

    let failure = ToolEventFailure::Output(ExecToolCallOutput {
        exit_code: -1,
        stdout: StreamOutput::new(captured),
        stderr: StreamOutput::new(message),
        aggregated_output: StreamOutput::new(combined),
        duration,
        timed_out: false,
    });

    let ctx = ToolEventCtx::new(
        session_ref.as_ref(),
        turn_ref.as_ref(),
        &call_id,
        /*turn_diff_tracker*/ None,
    );

    ToolEmitter::unified_exec(
        &command,
        cwd,
        ExecCommandSource::UnifiedExecStartup,
        process_id,
    )
    .emit(ctx, ToolEventStage::Failure(failure))
    .await;
}
fn split_valid_utf8_prefix(buffer: &mut Vec<u8>) -> Option<Vec<u8>> {
split_valid_utf8_prefix_with_max(buffer, UNIFIED_EXEC_OUTPUT_DELTA_MAX_BYTES)
}
+6
View File
@@ -5,6 +5,8 @@ use thiserror::Error;
pub(crate) enum UnifiedExecError {
#[error("Failed to create unified exec process: {message}")]
CreateProcess { message: String },
#[error("Unified exec process failed: {message}")]
ProcessFailed { message: String },
// The model is trained on `session_id`, but internally we track a `process_id`.
#[error("Unknown process id {process_id}")]
UnknownProcessId { process_id: i32 },
@@ -28,6 +30,10 @@ impl UnifiedExecError {
Self::CreateProcess { message }
}
/// Builds a [`UnifiedExecError::ProcessFailed`] carrying the given failure `message`.
pub(crate) fn process_failed(message: String) -> Self {
Self::ProcessFailed { message }
}
pub(crate) fn sandbox_denied(message: String, output: ExecToolCallOutput) -> Self {
Self::SandboxDenied { message, output }
}
+6
View File
@@ -19,6 +19,7 @@
//! This keeps policy logic and user interaction centralized while the PTY/process
//! concerns remain isolated here. The implementation is split between:
//! - `process.rs`: PTY process lifecycle + output buffering.
//! - `process_state.rs`: shared exit/failure state for local and remote processes.
//! - `process_manager.rs`: orchestration (approvals, sandboxing, reuse) and request handling.
use std::collections::HashMap;
@@ -42,6 +43,7 @@ mod errors;
mod head_tail_buffer;
mod process;
mod process_manager;
mod process_state;
pub(crate) fn set_deterministic_process_ids_for_tests(enabled: bool) {
process_manager::set_deterministic_process_ids_for_tests(enabled);
@@ -167,6 +169,10 @@ pub(crate) fn generate_chunk_id() -> String {
.collect()
}
#[cfg(test)]
#[cfg(unix)]
#[path = "process_tests.rs"]
mod process_tests;
#[cfg(test)]
#[cfg(unix)]
#[path = "mod_tests.rs"]
+295 -48
View File
@@ -3,27 +3,26 @@ use super::*;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::codex::make_session_and_context;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::exec::ExecCapturePolicy;
use crate::exec::ExecExpiration;
use crate::sandboxing::ExecRequest;
use crate::tools::context::ExecCommandToolOutput;
use crate::unified_exec::ExecCommandRequest;
use crate::unified_exec::WriteStdinRequest;
use crate::unified_exec::process::OutputHandles;
use codex_sandboxing::SandboxType;
use codex_utils_output_truncation::approx_token_count;
use core_test_support::get_remote_test_env;
use core_test_support::skip_if_sandbox;
use core_test_support::test_codex::test_env as remote_test_env;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::time::Duration;
use tokio::time::Instant;
async fn test_session_and_turn() -> (Arc<Session>, Arc<TurnContext>) {
let (session, mut turn) = make_session_and_context().await;
turn.approval_policy
.set(AskForApproval::Never)
.expect("test setup should allow updating approval policy");
turn.sandbox_policy
.set(SandboxPolicy::DangerFullAccess)
.expect("test setup should allow updating sandbox policy");
turn.file_system_sandbox_policy =
codex_protocol::permissions::FileSystemSandboxPolicy::from(turn.sandbox_policy.get());
turn.network_sandbox_policy =
codex_protocol::permissions::NetworkSandboxPolicy::from(turn.sandbox_policy.get());
let (session, turn) = make_session_and_context().await;
(Arc::new(session), Arc::new(turn))
}
@@ -32,36 +31,143 @@ async fn exec_command(
turn: &Arc<TurnContext>,
cmd: &str,
yield_time_ms: u64,
workdir: Option<PathBuf>,
) -> Result<ExecCommandToolOutput, UnifiedExecError> {
exec_command_with_tty(session, turn, cmd, yield_time_ms, workdir, true).await
}
/// Snapshots the current process environment into an owned map.
///
/// Like `std::env::vars`, this panics if any variable is not valid Unicode.
fn shell_env() -> HashMap<String, String> {
    let mut env = HashMap::new();
    env.extend(std::env::vars());
    env
}
/// Builds an [`ExecRequest`] for tests from the given command, working
/// directory, and environment.
///
/// The sandbox, file-system, and network sandbox policies and the Windows
/// sandbox level are taken from `turn`. The request uses
/// `SandboxType::None`, `ExecExpiration::DefaultTimeout`, and
/// `ExecCapturePolicy::ShellTool`, and passes `None` for both the network
/// and `arg0` arguments.
fn test_exec_request(
turn: &TurnContext,
command: Vec<String>,
cwd: PathBuf,
env: HashMap<String, String>,
) -> ExecRequest {
// Literal arguments are bound to named locals so the long positional
// call below stays readable.
let windows_sandbox_private_desktop = false;
let sandbox_policy = turn.sandbox_policy.get().clone();
let file_system_sandbox_policy = turn.file_system_sandbox_policy.clone();
let network_sandbox_policy = turn.network_sandbox_policy;
let network = None;
let arg0 = None;
ExecRequest::new(
command,
cwd,
env,
network,
ExecExpiration::DefaultTimeout,
ExecCapturePolicy::ShellTool,
SandboxType::None,
turn.windows_sandbox_level,
windows_sandbox_private_desktop,
sandbox_policy,
file_system_sandbox_policy,
network_sandbox_policy,
arg0,
)
}
async fn exec_command_with_tty(
session: &Arc<Session>,
turn: &Arc<TurnContext>,
cmd: &str,
yield_time_ms: u64,
workdir: Option<PathBuf>,
tty: bool,
) -> Result<ExecCommandToolOutput, UnifiedExecError> {
let manager = &session.services.unified_exec_manager;
let process_id = manager.allocate_process_id().await;
let cwd = workdir.unwrap_or_else(|| turn.cwd.clone().to_path_buf());
let command = vec!["bash".to_string(), "-lc".to_string(), cmd.to_string()];
let request = test_exec_request(turn, command.clone(), cwd.clone(), shell_env());
let process = Arc::new(
manager
.open_session_with_exec_env(
process_id,
&request,
tty,
Box::new(NoopSpawnLifecycle),
turn.environment.as_ref(),
)
.await?,
);
let context =
UnifiedExecContext::new(Arc::clone(session), Arc::clone(turn), "call".to_string());
let process_id = session
.services
.unified_exec_manager
.allocate_process_id()
.await;
let started_at = Instant::now();
let process_started_alive = !process.has_exited() && process.exit_code().is_none();
if process_started_alive {
let entry = ProcessEntry {
process: Arc::clone(&process),
call_id: context.call_id.clone(),
process_id,
command: command.clone(),
tty,
network_approval_id: None,
session: Arc::downgrade(session),
last_used: started_at,
};
manager
.process_store
.lock()
.await
.processes
.insert(process_id, entry);
}
session
.services
.unified_exec_manager
.exec_command(
ExecCommandRequest {
command: vec!["bash".to_string(), "-lc".to_string(), cmd.to_string()],
process_id,
yield_time_ms,
max_output_tokens: None,
workdir: None,
network: None,
tty: true,
sandbox_permissions: SandboxPermissions::UseDefault,
additional_permissions: None,
additional_permissions_preapproved: false,
justification: None,
prefix_rule: None,
},
&context,
)
.await
let OutputHandles {
output_buffer,
output_notify,
output_closed,
output_closed_notify,
cancellation_token,
} = process.output_handles();
let deadline = started_at + Duration::from_millis(yield_time_ms);
let collected = UnifiedExecProcessManager::collect_output_until_deadline(
&output_buffer,
&output_notify,
&output_closed,
&output_closed_notify,
&cancellation_token,
Some(session.subscribe_out_of_band_elicitation_pause_state()),
deadline,
)
.await;
let wall_time = Instant::now().saturating_duration_since(started_at);
let text = String::from_utf8_lossy(&collected).to_string();
let has_exited = process.has_exited();
let exit_code = process.exit_code();
let response_process_id = if process_started_alive && !has_exited {
Some(process_id)
} else {
manager.release_process_id(process_id).await;
None
};
Ok(ExecCommandToolOutput {
event_call_id: context.call_id,
chunk_id: generate_chunk_id(),
wall_time,
raw_output: collected,
max_output_tokens: None,
process_id: response_process_id,
exit_code,
original_token_count: Some(approx_token_count(&text)),
session_command: Some(command),
})
}
/// Spawn lifecycle stub that reports a fixed list of inherited file
/// descriptors.
#[derive(Debug)]
struct TestSpawnLifecycle {
    /// Descriptors handed back verbatim from `inherited_fds`.
    inherited_fds: Vec<i32>,
}

impl SpawnLifecycle for TestSpawnLifecycle {
    /// Returns a copy of the configured descriptor list.
    fn inherited_fds(&self) -> Vec<i32> {
        self.inherited_fds.to_vec()
    }
}
async fn write_stdin(
@@ -121,7 +227,7 @@ async fn unified_exec_persists_across_requests() -> anyhow::Result<()> {
let (session, turn) = test_session_and_turn().await;
let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?;
let open_shell = exec_command(&session, &turn, "bash -i", 2_500, None).await?;
let process_id = open_shell.process_id.expect("expected process_id");
write_stdin(
@@ -153,7 +259,7 @@ async fn multi_unified_exec_sessions() -> anyhow::Result<()> {
let (session, turn) = test_session_and_turn().await;
let shell_a = exec_command(&session, &turn, "bash -i", 2_500).await?;
let shell_a = exec_command(&session, &turn, "bash -i", 2_500, None).await?;
let session_a = shell_a.process_id.expect("expected process id");
write_stdin(
@@ -164,7 +270,14 @@ async fn multi_unified_exec_sessions() -> anyhow::Result<()> {
)
.await?;
let out_2 = exec_command(&session, &turn, "echo $CODEX_INTERACTIVE_SHELL_VAR", 2_500).await?;
let out_2 = exec_command(
&session,
&turn,
"echo $CODEX_INTERACTIVE_SHELL_VAR",
2_500,
None,
)
.await?;
tokio::time::sleep(Duration::from_secs(2)).await;
assert!(
out_2.process_id.is_none(),
@@ -198,7 +311,7 @@ async fn unified_exec_timeouts() -> anyhow::Result<()> {
let (session, turn) = test_session_and_turn().await;
let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?;
let open_shell = exec_command(&session, &turn, "bash -i", 2_500, None).await?;
let process_id = open_shell.process_id.expect("expected process id");
write_stdin(
@@ -247,7 +360,14 @@ async fn unified_exec_pause_blocks_yield_timeout() -> anyhow::Result<()> {
});
let started = tokio::time::Instant::now();
let response = exec_command(&session, &turn, "sleep 1 && echo unified-exec-done", 250).await?;
let response = exec_command(
&session,
&turn,
"sleep 1 && echo unified-exec-done",
250,
None,
)
.await?;
assert!(
started.elapsed() >= Duration::from_secs(2),
@@ -270,7 +390,7 @@ async fn unified_exec_pause_blocks_yield_timeout() -> anyhow::Result<()> {
async fn requests_with_large_timeout_are_capped() -> anyhow::Result<()> {
let (session, turn) = test_session_and_turn().await;
let result = exec_command(&session, &turn, "echo codex", 120_000).await?;
let result = exec_command(&session, &turn, "echo codex", 120_000, None).await?;
assert!(result.process_id.is_some());
assert!(result.truncated_output().contains("codex"));
@@ -282,7 +402,7 @@ async fn requests_with_large_timeout_are_capped() -> anyhow::Result<()> {
#[ignore] // Ignored while we have a better way to test this.
async fn completed_commands_do_not_persist_sessions() -> anyhow::Result<()> {
let (session, turn) = test_session_and_turn().await;
let result = exec_command(&session, &turn, "echo codex", 2_500).await?;
let result = exec_command(&session, &turn, "echo codex", 2_500, None).await?;
assert!(
result.process_id.is_some(),
@@ -310,7 +430,7 @@ async fn reusing_completed_process_returns_unknown_process() -> anyhow::Result<(
let (session, turn) = test_session_and_turn().await;
let open_shell = exec_command(&session, &turn, "bash -i", 2_500).await?;
let open_shell = exec_command(&session, &turn, "bash -i", 2_500, None).await?;
let process_id = open_shell.process_id.expect("expected process id");
write_stdin(&session, process_id, "exit\n", 2_500).await?;
@@ -341,3 +461,130 @@ async fn reusing_completed_process_returns_unknown_process() -> anyhow::Result<(
Ok(())
}
/// Launches `bash -lc "exit 17"` with `tty = false` against a default
/// `codex_exec_server::Environment` and checks that the resulting process
/// reports that it has exited with exit code 17.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn completed_pipe_commands_preserve_exit_code() -> anyhow::Result<()> {
let (_, turn) = make_session_and_context().await;
let request = test_exec_request(
&turn,
vec!["bash".to_string(), "-lc".to_string(), "exit 17".to_string()],
PathBuf::from("/tmp"),
shell_env(),
);
let environment = codex_exec_server::Environment::default();
let process = UnifiedExecProcessManager::default()
.open_session_with_exec_env(
1234,
&request,
false,
Box::new(NoopSpawnLifecycle),
&environment,
)
.await?;
// The process may still be running when the open call returns; wait up to
// two seconds for its cancellation token, used here as the exit signal.
if !process.has_exited() {
let exit_signal = process.cancellation_token();
assert!(
tokio::time::timeout(Duration::from_secs(2), exit_signal.cancelled())
.await
.is_ok(),
"process did not report exit within timeout"
);
}
assert!(process.has_exited());
assert_eq!(process.exit_code(), Some(17));
Ok(())
}
/// Opens an interactive `bash -i` session through the remote test
/// environment, writes a `printf` command to it, and asserts that the marker
/// text appears in the collected output.
///
/// Returns `Ok(())` without running when `get_remote_test_env()` yields
/// `None`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_uses_remote_exec_server_when_configured() -> anyhow::Result<()> {
// NOTE(review): presumably returns early with `Ok(())` when running inside
// a sandbox; confirm against the macro definition.
skip_if_sandbox!(Ok(()));
let Some(_remote_env) = get_remote_test_env() else {
return Ok(());
};
let remote_test_env = remote_test_env().await?;
let (_, turn) = make_session_and_context().await;
let request = test_exec_request(
&turn,
vec!["bash".to_string(), "-i".to_string()],
PathBuf::from("/tmp"),
shell_env(),
);
let manager = UnifiedExecProcessManager::default();
let process = manager
.open_session_with_exec_env(
1234,
&request,
true,
Box::new(NoopSpawnLifecycle),
remote_test_env.environment(),
)
.await?;
process.write(b"printf 'remote-unified-exec\\n'\n").await?;
// Short pause before polling for the command's output.
tokio::time::sleep(Duration::from_millis(100)).await;
let crate::unified_exec::process::OutputHandles {
output_buffer,
output_notify,
output_closed,
output_closed_notify,
cancellation_token,
} = process.output_handles();
// Collect output with a deadline of 2.5 seconds from now; no pause-state
// receiver is supplied.
let collected = UnifiedExecProcessManager::collect_output_until_deadline(
&output_buffer,
&output_notify,
&output_closed,
&output_closed_notify,
&cancellation_token,
None,
Instant::now() + Duration::from_millis(2_500),
)
.await;
assert!(String::from_utf8_lossy(&collected).contains("remote-unified-exec"));
Ok(())
}
/// Checks that opening a session against the remote test environment with a
/// spawn lifecycle that reports inherited file descriptors fails, and that
/// the error text matches the expected "create process" message exactly.
///
/// Returns `Ok(())` without running when `get_remote_test_env()` yields
/// `None`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_exec_server_rejects_inherited_fd_launches() -> anyhow::Result<()> {
// NOTE(review): presumably returns early with `Ok(())` when running inside
// a sandbox; confirm against the macro definition.
skip_if_sandbox!(Ok(()));
let Some(_remote_env) = get_remote_test_env() else {
return Ok(());
};
let remote_test_env = remote_test_env().await?;
let (_, mut turn) = make_session_and_context().await;
// Point the turn at the remote environment so the launch goes through it.
turn.environment = Arc::new(remote_test_env.environment().clone());
let request = test_exec_request(
&turn,
vec!["bash".to_string(), "-lc".to_string(), "echo ok".to_string()],
PathBuf::from("/tmp"),
shell_env(),
);
let manager = UnifiedExecProcessManager::default();
let err = manager
.open_session_with_exec_env(
1234,
&request,
true,
// Lifecycle stub that reports one inherited descriptor.
Box::new(TestSpawnLifecycle {
inherited_fds: vec![42],
}),
turn.environment.as_ref(),
)
.await
.expect_err("expected inherited fd rejection");
assert_eq!(
err.to_string(),
"Failed to create unified exec process: remote exec-server does not support inherited file descriptors"
);
Ok(())
}
+282 -58
View File
@@ -6,8 +6,8 @@ use std::sync::atomic::Ordering;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
@@ -15,8 +15,12 @@ use tokio_util::sync::CancellationToken;
use crate::exec::ExecToolCallOutput;
use crate::exec::StreamOutput;
use crate::exec::is_likely_sandbox_denied;
use codex_exec_server::ExecProcess;
use codex_exec_server::ReadResponse as ExecReadResponse;
use codex_exec_server::StartedExecProcess;
use codex_exec_server::WriteStatus;
use codex_protocol::protocol::TruncationPolicy;
use codex_sandboxing::SandboxType;
use codex_utils_output_truncation::TruncationPolicy;
use codex_utils_output_truncation::formatted_truncate_text;
use codex_utils_pty::ExecCommandSession;
use codex_utils_pty::SpawnedPty;
@@ -24,6 +28,9 @@ use codex_utils_pty::SpawnedPty;
use super::UNIFIED_EXEC_OUTPUT_MAX_TOKENS;
use super::UnifiedExecError;
use super::head_tail_buffer::HeadTailBuffer;
use super::process_state::ProcessState;
const EARLY_EXIT_GRACE_PERIOD: Duration = Duration::from_millis(150);
pub(crate) trait SpawnLifecycle: std::fmt::Debug + Send + Sync {
/// Returns file descriptors that must stay open across the child `exec()`.
@@ -41,11 +48,13 @@ pub(crate) trait SpawnLifecycle: std::fmt::Debug + Send + Sync {
pub(crate) type SpawnLifecycleHandle = Box<dyn SpawnLifecycle>;
#[derive(Debug, Default)]
/// Spawn lifecycle that performs no extra setup around process launch.
pub(crate) struct NoopSpawnLifecycle;
impl SpawnLifecycle for NoopSpawnLifecycle {}
pub(crate) type OutputBuffer = Arc<Mutex<HeadTailBuffer>>;
/// Shared output state exposed to polling and streaming consumers.
pub(crate) struct OutputHandles {
pub(crate) output_buffer: OutputBuffer,
pub(crate) output_notify: Arc<Notify>,
@@ -54,27 +63,44 @@ pub(crate) struct OutputHandles {
pub(crate) cancellation_token: CancellationToken,
}
#[derive(Debug)]
/// Transport-specific process handle used by unified exec.
enum ProcessHandle {
Local(Box<ExecCommandSession>),
Remote(Arc<dyn ExecProcess>),
}
/// Unified wrapper over local PTY sessions and exec-server-backed processes.
pub(crate) struct UnifiedExecProcess {
process_handle: ExecCommandSession,
output_rx: broadcast::Receiver<Vec<u8>>,
process_handle: ProcessHandle,
output_tx: broadcast::Sender<Vec<u8>>,
output_buffer: OutputBuffer,
output_notify: Arc<Notify>,
output_closed: Arc<AtomicBool>,
output_closed_notify: Arc<Notify>,
cancellation_token: CancellationToken,
output_drained: Arc<Notify>,
output_task: JoinHandle<()>,
state_tx: watch::Sender<ProcessState>,
state_rx: watch::Receiver<ProcessState>,
output_task: Option<JoinHandle<()>>,
sandbox_type: SandboxType,
_spawn_lifecycle: SpawnLifecycleHandle,
_spawn_lifecycle: Option<SpawnLifecycleHandle>,
}
// Hand-written `Debug`: only the exit status and sandbox type are printed;
// the remaining fields are elided via `finish_non_exhaustive`.
impl std::fmt::Debug for UnifiedExecProcess {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("UnifiedExecProcess");
        summary.field("has_exited", &self.has_exited());
        summary.field("exit_code", &self.exit_code());
        summary.field("sandbox_type", &self.sandbox_type);
        summary.finish_non_exhaustive()
    }
}
impl UnifiedExecProcess {
pub(super) fn new(
process_handle: ExecCommandSession,
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
fn new(
process_handle: ProcessHandle,
sandbox_type: SandboxType,
spawn_lifecycle: SpawnLifecycleHandle,
spawn_lifecycle: Option<SpawnLifecycleHandle>,
) -> Self {
let output_buffer = Arc::new(Mutex::new(HeadTailBuffer::default()));
let output_notify = Arc::new(Notify::new());
@@ -82,48 +108,49 @@ impl UnifiedExecProcess {
let output_closed_notify = Arc::new(Notify::new());
let cancellation_token = CancellationToken::new();
let output_drained = Arc::new(Notify::new());
let mut receiver = initial_output_rx;
let output_rx = receiver.resubscribe();
let buffer_clone = Arc::clone(&output_buffer);
let notify_clone = Arc::clone(&output_notify);
let output_closed_clone = Arc::clone(&output_closed);
let output_closed_notify_clone = Arc::clone(&output_closed_notify);
let output_task = tokio::spawn(async move {
loop {
match receiver.recv().await {
Ok(chunk) => {
let mut guard = buffer_clone.lock().await;
guard.push_chunk(chunk);
drop(guard);
notify_clone.notify_waiters();
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
output_closed_clone.store(true, Ordering::Release);
output_closed_notify_clone.notify_waiters();
break;
}
};
}
});
let (output_tx, _) = broadcast::channel(64);
let (state_tx, state_rx) = watch::channel(ProcessState::default());
Self {
process_handle,
output_rx,
output_tx,
output_buffer,
output_notify,
output_closed,
output_closed_notify,
cancellation_token,
output_drained,
output_task,
state_tx,
state_rx,
output_task: None,
sandbox_type,
_spawn_lifecycle: spawn_lifecycle,
}
}
pub(super) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.process_handle.writer_sender()
/// Sends `data` to the process's stdin over whichever transport backs it.
///
/// # Errors
///
/// * `UnifiedExecError::WriteToStdin` when the local writer channel rejects
///   the send, or when the remote side answers with any status other than
///   `Accepted`.
/// * `UnifiedExecError::process_failed(..)` when the remote write call itself
///   returns an error; the error text is carried in the message.
pub(super) async fn write(&self, data: &[u8]) -> Result<(), UnifiedExecError> {
match &self.process_handle {
// Local: hand an owned copy of the bytes to the session's writer channel.
ProcessHandle::Local(process_handle) => process_handle
.writer_sender()
.send(data.to_vec())
.await
.map_err(|_| UnifiedExecError::WriteToStdin),
ProcessHandle::Remote(process_handle) => {
match process_handle.write(data.to_vec()).await {
Ok(response) => match response.status {
WriteStatus::Accepted => Ok(()),
WriteStatus::UnknownProcess | WriteStatus::StdinClosed => {
// The remote side can no longer take input: publish the process as
// exited (keeping whatever exit code is already recorded) and cancel
// the token before reporting the write failure.
let state = self.state_rx.borrow().clone();
let _ = self.state_tx.send_replace(state.exited(state.exit_code));
self.cancellation_token.cancel();
Err(UnifiedExecError::WriteToStdin)
}
// `Starting` is reported as a write failure without touching the state.
WriteStatus::Starting => Err(UnifiedExecError::WriteToStdin),
},
Err(err) => Err(UnifiedExecError::process_failed(err.to_string())),
}
}
}
}
pub(super) fn output_handles(&self) -> OutputHandles {
@@ -137,7 +164,7 @@ impl UnifiedExecProcess {
}
pub(super) fn output_receiver(&self) -> tokio::sync::broadcast::Receiver<Vec<u8>> {
self.output_rx.resubscribe()
self.output_tx.subscribe()
}
pub(super) fn cancellation_token(&self) -> CancellationToken {
@@ -149,19 +176,39 @@ impl UnifiedExecProcess {
}
pub(super) fn has_exited(&self) -> bool {
self.process_handle.has_exited()
let state = self.state_rx.borrow().clone();
match &self.process_handle {
ProcessHandle::Local(process_handle) => state.has_exited || process_handle.has_exited(),
ProcessHandle::Remote(_) => state.has_exited,
}
}
pub(super) fn exit_code(&self) -> Option<i32> {
self.process_handle.exit_code()
let state = self.state_rx.borrow().clone();
match &self.process_handle {
ProcessHandle::Local(process_handle) => {
state.exit_code.or_else(|| process_handle.exit_code())
}
ProcessHandle::Remote(_) => state.exit_code,
}
}
pub(super) fn terminate(&self) {
self.output_closed.store(true, Ordering::Release);
self.output_closed_notify.notify_waiters();
self.process_handle.terminate();
match &self.process_handle {
ProcessHandle::Local(process_handle) => process_handle.terminate(),
ProcessHandle::Remote(process_handle) => {
let process_handle = Arc::clone(process_handle);
tokio::spawn(async move {
let _ = process_handle.terminate().await;
});
}
}
self.cancellation_token.cancel();
self.output_task.abort();
if let Some(output_task) = &self.output_task {
output_task.abort();
}
}
async fn snapshot_output(&self) -> Vec<Vec<u8>> {
@@ -173,6 +220,10 @@ impl UnifiedExecProcess {
self.sandbox_type
}
/// Returns a copy of the failure message currently recorded in the process
/// state, or `None` when no failure has been recorded.
pub(super) fn failure_message(&self) -> Option<String> {
    let state = self.state_rx.borrow();
    state.failure_message.clone()
}
pub(super) async fn check_for_sandbox_denial(&self) -> Result<(), UnifiedExecError> {
let _ =
tokio::time::timeout(Duration::from_millis(20), self.output_notify.notified()).await;
@@ -232,29 +283,47 @@ impl UnifiedExecProcess {
mut exit_rx,
} = spawned;
let output_rx = codex_utils_pty::combine_output_receivers(stdout_rx, stderr_rx);
let managed = Self::new(process_handle, output_rx, sandbox_type, spawn_lifecycle);
let mut managed = Self::new(
ProcessHandle::Local(Box::new(process_handle)),
sandbox_type,
Some(spawn_lifecycle),
);
managed.output_task = Some(Self::spawn_local_output_task(
output_rx,
Arc::clone(&managed.output_buffer),
Arc::clone(&managed.output_notify),
Arc::clone(&managed.output_closed),
Arc::clone(&managed.output_closed_notify),
managed.output_tx.clone(),
));
let exit_ready = matches!(exit_rx.try_recv(), Ok(_) | Err(TryRecvError::Closed));
if exit_ready {
managed.signal_exit();
managed.check_for_sandbox_denial().await?;
return Ok(managed);
match exit_rx.try_recv() {
Ok(exit_code) => {
managed.signal_exit(Some(exit_code));
managed.check_for_sandbox_denial().await?;
return Ok(managed);
}
Err(TryRecvError::Closed) => {
managed.signal_exit(/*exit_code*/ None);
managed.check_for_sandbox_denial().await?;
return Ok(managed);
}
Err(TryRecvError::Empty) => {}
}
if tokio::time::timeout(Duration::from_millis(150), &mut exit_rx)
.await
.is_ok()
{
managed.signal_exit();
if let Ok(exit_result) = tokio::time::timeout(EARLY_EXIT_GRACE_PERIOD, &mut exit_rx).await {
managed.signal_exit(exit_result.ok());
managed.check_for_sandbox_denial().await?;
return Ok(managed);
}
tokio::spawn({
let state_tx = managed.state_tx.clone();
let cancellation_token = managed.cancellation_token.clone();
async move {
let _ = exit_rx.await;
let exit_code = exit_rx.await.ok();
let state = state_tx.borrow().clone();
let _ = state_tx.send_replace(state.exited(exit_code));
cancellation_token.cancel();
}
});
@@ -262,7 +331,162 @@ impl UnifiedExecProcess {
Ok(managed)
}
fn signal_exit(&self) {
/// Wraps an already-started exec-server process in a `UnifiedExecProcess`.
///
/// Starts the remote output task, then waits up to `EARLY_EXIT_GRACE_PERIOD`
/// for the process state to report an exit or a failure. If that happens
/// within the grace period, `check_for_sandbox_denial` runs and its error (if
/// any) is returned; otherwise the still-running process is returned as-is.
pub(super) async fn from_remote_started(
started: StartedExecProcess,
sandbox_type: SandboxType,
) -> Result<Self, UnifiedExecError> {
let process_handle = ProcessHandle::Remote(Arc::clone(&started.process));
let mut managed = Self::new(process_handle, sandbox_type, /*spawn_lifecycle*/ None);
let output_handles = managed.output_handles();
managed.output_task = Some(Self::spawn_remote_output_task(
started,
output_handles,
managed.output_tx.clone(),
managed.state_tx.clone(),
));
let mut state_rx = managed.state_rx.clone();
// Watch the state until it turns terminal, bounded by the grace period.
if tokio::time::timeout(EARLY_EXIT_GRACE_PERIOD, async {
loop {
let state = state_rx.borrow().clone();
if state.has_exited || state.failure_message.is_some() {
break;
}
// An error here means the state sender is gone; stop waiting.
if state_rx.changed().await.is_err() {
break;
}
}
})
.await
.is_ok()
{
managed.check_for_sandbox_denial().await?;
}
Ok(managed)
}
/// Spawns the task that pulls output and lifecycle state from a remote
/// process into the shared output buffer, the broadcast channel and the
/// process-state watch channel.
///
/// Each iteration issues a `read` with `wait_ms` of `Some(0)`, applies the
/// response, and then waits for the wake channel to change before reading
/// again. The loop ends when the response carries a failure, when output is
/// marked closed, when `read` returns an error, or when the wake channel is
/// closed.
fn spawn_remote_output_task(
started: StartedExecProcess,
output_handles: OutputHandles,
output_tx: broadcast::Sender<Vec<u8>>,
state_tx: watch::Sender<ProcessState>,
) -> JoinHandle<()> {
let OutputHandles {
output_buffer,
output_notify,
output_closed,
output_closed_notify,
cancellation_token,
} = output_handles;
let process = started.process;
// Subscribe before spawning so the receiver exists before the first read.
let mut wake_rx = process.subscribe_wake();
tokio::spawn(async move {
let mut after_seq = None;
loop {
match process
.read(after_seq, /*max_bytes*/ None, /*wait_ms*/ Some(0))
.await
{
Ok(response) => {
let ExecReadResponse {
chunks,
next_seq,
exited,
exit_code,
closed,
failure,
} = response;
// Chunks are applied before failure/exit handling, so output that
// arrives in the same response as a terminal status is still recorded.
for chunk in chunks {
let bytes = chunk.chunk.into_inner();
let mut guard = output_buffer.lock().await;
guard.push_chunk(bytes.clone());
drop(guard);
let _ = output_tx.send(bytes);
output_notify.notify_waiters();
}
if let Some(message) = failure {
// Terminal failure: record it, close output, cancel and stop.
let state = state_tx.borrow().clone();
let _ = state_tx.send_replace(state.failed(message));
output_closed.store(true, Ordering::Release);
output_closed_notify.notify_waiters();
cancellation_token.cancel();
break;
}
if exited {
// Exit alone does not end the loop; only `closed` does.
let state = state_tx.borrow().clone();
let _ = state_tx.send_replace(state.exited(exit_code));
}
if closed {
output_closed.store(true, Ordering::Release);
output_closed_notify.notify_waiters();
cancellation_token.cancel();
}
// The next read resumes after the last sequence number covered by this
// response; `checked_sub` yields `None` when `next_seq` is 0.
after_seq = next_seq.checked_sub(1);
if output_closed.load(Ordering::Acquire) {
break;
}
}
Err(err) => {
let state = state_tx.borrow().clone();
let _ = state_tx.send_replace(state.failed(err.to_string()));
output_closed.store(true, Ordering::Release);
output_closed_notify.notify_waiters();
cancellation_token.cancel();
break;
}
}
// Sleep until the wake channel changes; an error means its sender is gone.
if wake_rx.changed().await.is_err() {
let state = state_tx.borrow().clone();
let _ = state_tx
.send_replace(state.failed("exec-server wake channel closed".to_string()));
output_closed.store(true, Ordering::Release);
output_closed_notify.notify_waiters();
cancellation_token.cancel();
break;
}
}
})
}
fn spawn_local_output_task(
mut receiver: tokio::sync::broadcast::Receiver<Vec<u8>>,
buffer: OutputBuffer,
output_notify: Arc<Notify>,
output_closed: Arc<AtomicBool>,
output_closed_notify: Arc<Notify>,
output_tx: broadcast::Sender<Vec<u8>>,
) -> JoinHandle<()> {
tokio::spawn(async move {
loop {
match receiver.recv().await {
Ok(chunk) => {
let mut guard = buffer.lock().await;
guard.push_chunk(chunk.clone());
drop(guard);
let _ = output_tx.send(chunk);
output_notify.notify_waiters();
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
output_closed.store(true, Ordering::Release);
output_closed_notify.notify_waiters();
break;
}
};
}
})
}
/// Publishes the process as exited with `exit_code` and cancels the
/// cancellation token.
fn signal_exit(&self, exit_code: Option<i32>) {
    let exited_state = self.state_rx.borrow().exited(exit_code);
    let _ = self.state_tx.send_replace(exited_state);
    self.cancellation_token.cancel();
}
}
@@ -7,7 +7,6 @@ use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::Instant;
@@ -40,6 +39,7 @@ use crate::unified_exec::UnifiedExecProcessManager;
use crate::unified_exec::WARNING_UNIFIED_EXEC_PROCESSES;
use crate::unified_exec::WriteStdinRequest;
use crate::unified_exec::async_watcher::emit_exec_end_for_unified_exec;
use crate::unified_exec::async_watcher::emit_failed_exec_end_for_unified_exec;
use crate::unified_exec::async_watcher::spawn_exit_watcher;
use crate::unified_exec::async_watcher::start_streaming_output;
use crate::unified_exec::clamp_yield_time;
@@ -89,8 +89,9 @@ fn apply_unified_exec_env(mut env: HashMap<String, String>) -> HashMap<String, S
env
}
/// Borrowed process state prepared for a `write_stdin` or poll operation.
struct PreparedProcessHandles {
writer_tx: mpsc::Sender<Vec<u8>>,
process: Arc<UnifiedExecProcess>,
output_buffer: OutputBuffer,
output_notify: Arc<Notify>,
output_closed: Arc<AtomicBool>,
@@ -102,6 +103,10 @@ struct PreparedProcessHandles {
tty: bool,
}
/// Renders a unified-exec process id as the decimal string used for the
/// exec-server `process_id` field.
fn exec_server_process_id(process_id: i32) -> String {
    i32::to_string(&process_id)
}
impl UnifiedExecProcessManager {
pub(crate) async fn allocate_process_id(&self) -> i32 {
loop {
@@ -243,6 +248,29 @@ impl UnifiedExecProcessManager {
let text = String::from_utf8_lossy(&collected).to_string();
let chunk_id = generate_chunk_id();
if let Some(message) = process.failure_message() {
if !process_started_alive {
emit_failed_exec_end_for_unified_exec(
Arc::clone(&context.session),
Arc::clone(&context.turn),
context.call_id.clone(),
request.command.clone(),
cwd.clone(),
Some(request.process_id.to_string()),
Arc::clone(&transcript),
message.clone(),
wall_time,
)
.await;
}
self.release_process_id(request.process_id).await;
finish_deferred_network_approval(
context.session.as_ref(),
deferred_network_approval.take(),
)
.await;
return Err(UnifiedExecError::process_failed(message));
}
let process_id = request.process_id;
let (response_process_id, exit_code) = if process_started_alive {
match self.refresh_process_state(process_id).await {
@@ -312,7 +340,7 @@ impl UnifiedExecProcessManager {
let process_id = request.process_id;
let PreparedProcessHandles {
writer_tx,
process,
output_buffer,
output_notify,
output_closed,
@@ -324,15 +352,31 @@ impl UnifiedExecProcessManager {
tty,
..
} = self.prepare_process_handles(process_id).await?;
let mut status_after_write = None;
if !request.input.is_empty() {
if !tty {
return Err(UnifiedExecError::StdinClosed);
}
Self::send_input(&writer_tx, request.input.as_bytes()).await?;
// Give the remote process a brief window to react so that we are
// more likely to capture its output in the poll below.
tokio::time::sleep(Duration::from_millis(100)).await;
match process.write(request.input.as_bytes()).await {
Ok(()) => {
// Give the remote process a brief window to react so that we are
// more likely to capture its output in the poll below.
tokio::time::sleep(Duration::from_millis(100)).await;
}
Err(err) => {
let status = self.refresh_process_state(process_id).await;
if matches!(status, ProcessStatus::Exited { .. }) {
status_after_write = Some(status);
} else if matches!(err, UnifiedExecError::ProcessFailed { .. }) {
process.terminate();
self.release_process_id(process_id).await;
return Err(err);
} else {
return Err(err);
}
}
}
}
let yield_time_ms = {
@@ -362,12 +406,20 @@ impl UnifiedExecProcessManager {
let text = String::from_utf8_lossy(&collected).to_string();
let original_token_count = approx_token_count(&text);
let chunk_id = generate_chunk_id();
if let Some(message) = process.failure_message() {
self.release_process_id(process_id).await;
return Err(UnifiedExecError::process_failed(message));
}
// After polling, refresh_process_state tells us whether the PTY is
// still alive or has exited and been removed from the store; we thread
// that through so the handler can tag TerminalInteraction with an
// appropriate process_id and exit_code.
let status = self.refresh_process_state(process_id).await;
let status = if let Some(status) = status_after_write {
status
} else {
self.refresh_process_state(process_id).await
};
let (process_id, exit_code, event_call_id) = match status {
ProcessStatus::Alive {
exit_code,
@@ -455,7 +507,7 @@ impl UnifiedExecProcessManager {
.map(|session| session.subscribe_out_of_band_elicitation_pause_state());
Ok(PreparedProcessHandles {
writer_tx: entry.process.writer_sender(),
process: Arc::clone(&entry.process),
output_buffer,
output_notify,
output_closed,
@@ -468,16 +520,6 @@ impl UnifiedExecProcessManager {
})
}
async fn send_input(
writer_tx: &mpsc::Sender<Vec<u8>>,
data: &[u8],
) -> Result<(), UnifiedExecError> {
writer_tx
.send(data.to_vec())
.await
.map_err(|_| UnifiedExecError::WriteToStdin)
}
#[allow(clippy::too_many_arguments)]
async fn store_process(
&self,
@@ -539,9 +581,11 @@ impl UnifiedExecProcessManager {
pub(crate) async fn open_session_with_exec_env(
&self,
process_id: i32,
env: &ExecRequest,
tty: bool,
mut spawn_lifecycle: SpawnLifecycleHandle,
environment: &codex_exec_server::Environment,
) -> Result<UnifiedExecProcess, UnifiedExecError> {
let (program, args) = env
.command
@@ -549,6 +593,28 @@ impl UnifiedExecProcessManager {
.ok_or(UnifiedExecError::MissingCommandLine)?;
let inherited_fds = spawn_lifecycle.inherited_fds();
if environment.exec_server_url().is_some() {
if !inherited_fds.is_empty() {
return Err(UnifiedExecError::create_process(
"remote exec-server does not support inherited file descriptors".to_string(),
));
}
let started = environment
.get_exec_backend()
.start(codex_exec_server::ExecParams {
process_id: exec_server_process_id(process_id),
argv: env.command.clone(),
cwd: env.cwd.clone(),
env: env.env.clone(),
tty,
arg0: env.arg0.clone(),
})
.await
.map_err(|err| UnifiedExecError::create_process(err.to_string()))?;
return UnifiedExecProcess::from_remote_started(started, env.sandbox).await;
}
let spawn_result = if tty {
codex_utils_pty::pty::spawn_process_with_inherited_fds(
program,
@@ -611,6 +677,7 @@ impl UnifiedExecProcessManager {
.await;
let req = UnifiedExecToolRequest {
command: request.command.clone(),
process_id: request.process_id,
cwd,
env,
explicit_env_overrides: context.turn.shell_environment_policy.r#set.clone(),
@@ -34,6 +34,11 @@ fn unified_exec_env_overrides_existing_values() {
assert_eq!(env.get("PATH"), Some(&"/usr/bin".to_string()));
}
// The exec-server process id is the decimal rendering of the unified-exec id.
#[test]
fn exec_server_process_id_matches_unified_exec_process_id() {
assert_eq!(exec_server_process_id(4321), "4321");
}
#[test]
fn pruning_prefers_exited_processes_outside_recently_used() {
let now = Instant::now();
@@ -0,0 +1,24 @@
/// Lifecycle snapshot of a unified-exec process.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(crate) struct ProcessState {
    /// Set once the process has been reported as exited or failed.
    pub(crate) has_exited: bool,
    /// Exit code, when one is known.
    pub(crate) exit_code: Option<i32>,
    /// Failure message, when a failure has been recorded.
    pub(crate) failure_message: Option<String>,
}

impl ProcessState {
    /// Returns a copy marked as exited with `exit_code`.
    ///
    /// The given `exit_code` replaces the stored one (even when it is `None`);
    /// the failure message is carried over.
    pub(crate) fn exited(&self, exit_code: Option<i32>) -> Self {
        let mut next = self.clone();
        next.has_exited = true;
        next.exit_code = exit_code;
        next
    }

    /// Returns a copy marked as exited with `message` as the failure message.
    ///
    /// The stored exit code is carried over unchanged.
    pub(crate) fn failed(&self, message: String) -> Self {
        let mut next = self.clone();
        next.has_exited = true;
        next.failure_message = Some(message);
        next
    }
}
@@ -0,0 +1,142 @@
use super::process::UnifiedExecProcess;
use crate::unified_exec::UnifiedExecError;
use async_trait::async_trait;
use codex_exec_server::ExecProcess;
use codex_exec_server::ExecServerError;
use codex_exec_server::ProcessId;
use codex_exec_server::ReadResponse;
use codex_exec_server::StartedExecProcess;
use codex_exec_server::WriteResponse;
use codex_exec_server::WriteStatus;
use codex_sandboxing::SandboxType;
use pretty_assertions::assert_eq;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::watch;
use tokio::time::Duration;
/// Test double for `ExecProcess` with scripted read and write responses.
struct MockExecProcess {
// Id returned from `process_id()`.
process_id: ProcessId,
// Response cloned out for every `write` call.
write_response: WriteResponse,
// Scripted `read` responses, consumed front to back.
read_responses: Mutex<VecDeque<ReadResponse>>,
// Sender backing the receivers handed out by `subscribe_wake`.
wake_tx: watch::Sender<u64>,
}
// `ExecProcess` implementation that replays the scripted responses and never
// returns an error.
#[async_trait]
impl ExecProcess for MockExecProcess {
fn process_id(&self) -> &ProcessId {
&self.process_id
}
fn subscribe_wake(&self) -> watch::Receiver<u64> {
self.wake_tx.subscribe()
}
// Pops the next scripted response; the arguments are ignored. Once the
// script is exhausted, an empty "still running" response is returned.
async fn read(
&self,
_after_seq: Option<u64>,
_max_bytes: Option<usize>,
_wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError> {
Ok(self
.read_responses
.lock()
.await
.pop_front()
.unwrap_or(ReadResponse {
chunks: Vec::new(),
next_seq: 1,
exited: false,
exit_code: None,
closed: false,
failure: None,
}))
}
// Always answers with the configured response, whatever the chunk.
async fn write(&self, _chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
Ok(self.write_response.clone())
}
async fn terminate(&self) -> Result<(), ExecServerError> {
Ok(())
}
}
/// Builds a remote-backed `UnifiedExecProcess` whose mock answers every write
/// with `write_status` and has no scripted read responses.
async fn remote_process(write_status: WriteStatus) -> UnifiedExecProcess {
let (wake_tx, _wake_rx) = watch::channel(0);
let started = StartedExecProcess {
process: Arc::new(MockExecProcess {
process_id: "test-process".to_string().into(),
write_response: WriteResponse {
status: write_status,
},
read_responses: Mutex::new(VecDeque::new()),
wake_tx,
}),
};
UnifiedExecProcess::from_remote_started(started, SandboxType::None)
.await
.expect("remote process should start")
}
// A write answered with `UnknownProcess` fails with `WriteToStdin` and leaves
// the process marked as exited.
#[tokio::test]
async fn remote_write_unknown_process_marks_process_exited() {
let process = remote_process(WriteStatus::UnknownProcess).await;
let err = process
.write(b"hello")
.await
.expect_err("expected write failure");
assert!(matches!(err, UnifiedExecError::WriteToStdin));
assert!(process.has_exited());
}
// A write answered with `StdinClosed` fails with `WriteToStdin` and leaves
// the process marked as exited.
#[tokio::test]
async fn remote_write_closed_stdin_marks_process_exited() {
let process = remote_process(WriteStatus::StdinClosed).await;
let err = process
.write(b"hello")
.await
.expect_err("expected write failure");
assert!(matches!(err, UnifiedExecError::WriteToStdin));
assert!(process.has_exited());
}
// When the first read already reports exit code 17 with output closed,
// `from_remote_started` returns a process that exposes that exit status.
#[tokio::test]
async fn remote_process_waits_for_early_exit_event() {
let (wake_tx, _wake_rx) = watch::channel(0);
let started = StartedExecProcess {
process: Arc::new(MockExecProcess {
process_id: "test-process".to_string().into(),
write_response: WriteResponse {
status: WriteStatus::Accepted,
},
read_responses: Mutex::new(VecDeque::from([ReadResponse {
chunks: Vec::new(),
next_seq: 2,
exited: true,
exit_code: Some(17),
closed: true,
failure: None,
}])),
wake_tx: wake_tx.clone(),
}),
};
// Fire a wake shortly after start, on a background task.
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
let _ = wake_tx.send(1);
});
let process = UnifiedExecProcess::from_remote_started(started, SandboxType::None)
.await
.expect("remote process should observe early exit");
assert!(process.has_exited());
assert_eq!(process.exit_code(), Some(17));
}
+1
View File
@@ -15,6 +15,7 @@ path = "src/bin/codex-exec-server.rs"
workspace = true
[dependencies]
arc-swap = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
clap = { workspace = true, features = ["derive"] }
+395 -14
View File
@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use codex_app_server_protocol::FsCopyParams;
use codex_app_server_protocol::FsCopyResponse;
use codex_app_server_protocol::FsCreateDirectoryParams;
@@ -17,22 +19,25 @@ use codex_app_server_protocol::FsWriteFileParams;
use codex_app_server_protocol::FsWriteFileResponse;
use codex_app_server_protocol::JSONRPCNotification;
use serde_json::Value;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use tokio::sync::watch;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tracing::debug;
use tracing::warn;
use crate::ProcessId;
use crate::client_api::ExecServerClientConnectOptions;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::process::ExecServerEvent;
use crate::protocol::EXEC_CLOSED_METHOD;
use crate::protocol::EXEC_EXITED_METHOD;
use crate::protocol::EXEC_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::EXEC_READ_METHOD;
use crate::protocol::EXEC_TERMINATE_METHOD;
use crate::protocol::EXEC_WRITE_METHOD;
use crate::protocol::ExecClosedNotification;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
use crate::protocol::ExecParams;
@@ -90,9 +95,29 @@ impl RemoteExecServerConnectArgs {
}
}
/// Per-process wake and failure state shared between the client's session
/// registry and the `Session` handle.
pub(crate) struct SessionState {
// Wake counter: raised to the highest notified `seq` by `note_change` and
// bumped by one in `set_failure`.
wake_tx: watch::Sender<u64>,
// First failure message recorded for this session, if any.
failure: Mutex<Option<String>>,
}
/// Handle for one registered process: pairs the client with the process id
/// and that process's shared `SessionState`.
#[derive(Clone)]
pub(crate) struct Session {
// Client used for read/write/terminate/unregister calls.
client: ExecServerClient,
// Id of the process this session is bound to.
process_id: ProcessId,
// State shared with the client's session registry.
state: Arc<SessionState>,
}
struct Inner {
client: RpcClient,
events_tx: broadcast::Sender<ExecServerEvent>,
// The remote transport delivers one shared notification stream for every
// process on the connection. Keep a local process_id -> session registry so
// we can turn those connection-global notifications into process wakeups
// without making notifications the source of truth for output delivery.
sessions: ArcSwap<HashMap<String, Arc<SessionState>>>,
// ArcSwap makes reads cheap on the hot notification path, but writes still
// need serialization so concurrent register/remove operations do not
// overwrite each other's copy-on-write updates.
sessions_write_lock: Mutex<()>,
reader_task: tokio::task::JoinHandle<()>,
}
@@ -158,10 +183,6 @@ impl ExecServerClient {
.await
}
pub fn event_receiver(&self) -> broadcast::Receiver<ExecServerEvent> {
self.inner.events_tx.subscribe()
}
pub async fn initialize(
&self,
options: ExecServerClientConnectOptions,
@@ -307,6 +328,25 @@ impl ExecServerClient {
.map_err(Into::into)
}
/// Registers a fresh session for `process_id` and returns its handle.
///
/// # Errors
///
/// Propagates the error from `insert_session`, which rejects a `process_id`
/// that already has a registered session.
pub(crate) async fn register_session(
    &self,
    process_id: &str,
) -> Result<Session, ExecServerError> {
    let state = Arc::new(SessionState::new());
    let registered = Arc::clone(&state);
    self.inner.insert_session(process_id, registered).await?;
    let session = Session {
        client: self.clone(),
        process_id: process_id.to_string().into(),
        state,
    };
    Ok(session)
}
/// Removes the session registered for `process_id`, if there is one.
pub(crate) async fn unregister_session(&self, process_id: &str) {
    let _removed = self.inner.remove_session(process_id).await;
}
async fn connect(
connection: JsonRpcConnection,
options: ExecServerClientConnectOptions,
@@ -322,13 +362,18 @@ impl ExecServerClient {
&& let Err(err) =
handle_server_notification(&inner, notification).await
{
warn!("exec-server client closing after protocol error: {err}");
fail_all_sessions(
&inner,
format!("exec-server notification handling failed: {err}"),
)
.await;
return;
}
}
RpcClientEvent::Disconnected { reason } => {
if let Some(reason) = reason {
warn!("exec-server client transport disconnected: {reason}");
if let Some(inner) = weak.upgrade() {
fail_all_sessions(&inner, disconnected_message(reason.as_deref()))
.await;
}
return;
}
@@ -338,7 +383,8 @@ impl ExecServerClient {
Inner {
client: rpc_client,
events_tx: broadcast::channel(256).0,
sessions: ArcSwap::from_pointee(HashMap::new()),
sessions_write_lock: Mutex::new(()),
reader_task,
}
});
@@ -370,6 +416,177 @@ impl From<RpcCallError> for ExecServerError {
}
}
impl SessionState {
    /// Creates a state with the wake counter at 0 and no failure recorded.
    fn new() -> Self {
        // The initial receiver is dropped on purpose; consumers obtain their
        // own receiver through `subscribe`.
        let (wake_tx, _wake_rx) = watch::channel(0);
        Self {
            wake_tx,
            failure: Mutex::new(None),
        }
    }

    /// Returns a new receiver for the wake counter.
    pub(crate) fn subscribe(&self) -> watch::Receiver<u64> {
        self.wake_tx.subscribe()
    }

    /// Raises the wake counter to at least `seq` and wakes every receiver.
    fn note_change(&self, seq: u64) {
        // `send_modify` updates the value atomically and also when no receiver
        // is subscribed. A separate read followed by `send` can lose an update
        // under concurrency, and `send` discards the value when the channel
        // currently has no receivers.
        self.wake_tx
            .send_modify(|current| *current = (*current).max(seq));
    }

    /// Records `message` as the session failure unless one is already stored,
    /// then bumps the wake counter so waiters observe the change.
    async fn set_failure(&self, message: String) {
        {
            let mut failure = self.failure.lock().await;
            if failure.is_none() {
                *failure = Some(message);
            }
        }
        // Same reasoning as in `note_change`: update in place so the bump is
        // neither racy nor dropped when nobody is subscribed yet.
        self.wake_tx
            .send_modify(|current| *current = current.saturating_add(1));
    }

    /// Returns a synthesized terminal `ReadResponse` when a failure has been
    /// recorded, or `None` otherwise.
    async fn failed_response(&self) -> Option<ReadResponse> {
        self.failure
            .lock()
            .await
            .clone()
            .map(|message| self.synthesized_failure(message))
    }

    /// Builds a terminal `ReadResponse` carrying `message`: no chunks, exited
    /// and closed, no exit code, and `next_seq` one past the wake counter.
    fn synthesized_failure(&self, message: String) -> ReadResponse {
        let next_seq = (*self.wake_tx.borrow()).saturating_add(1);
        ReadResponse {
            chunks: Vec::new(),
            next_seq,
            exited: true,
            exit_code: None,
            closed: true,
            failure: Some(message),
        }
    }
}
impl Session {
    /// Id of the process this session is bound to.
    pub(crate) fn process_id(&self) -> &ProcessId {
        &self.process_id
    }

    /// Returns a new receiver for this session's wake counter.
    pub(crate) fn subscribe_wake(&self) -> watch::Receiver<u64> {
        self.state.subscribe()
    }

    /// Reads output for this process.
    ///
    /// A recorded session failure short-circuits to a synthesized terminal
    /// response without contacting the server. A transport-closed error from
    /// the client is recorded as the session failure and likewise turned into
    /// a synthesized terminal response; any other error is returned as-is.
    pub(crate) async fn read(
        &self,
        after_seq: Option<u64>,
        max_bytes: Option<usize>,
        wait_ms: Option<u64>,
    ) -> Result<ReadResponse, ExecServerError> {
        if let Some(response) = self.state.failed_response().await {
            return Ok(response);
        }

        let params = ReadParams {
            process_id: self.process_id.to_string(),
            after_seq,
            max_bytes,
            wait_ms,
        };
        let err = match self.client.read(params).await {
            Ok(response) => return Ok(response),
            Err(err) => err,
        };
        if !is_transport_closed_error(&err) {
            return Err(err);
        }

        let message = disconnected_message(/*reason*/ None);
        self.state.set_failure(message.clone()).await;
        Ok(self.state.synthesized_failure(message))
    }

    /// Forwards `chunk` to the client's `write` for this process.
    pub(crate) async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
        let process_id = &self.process_id;
        self.client.write(process_id, chunk).await
    }

    /// Asks the client to terminate this process, discarding the response.
    pub(crate) async fn terminate(&self) -> Result<(), ExecServerError> {
        let process_id = &self.process_id;
        self.client.terminate(process_id).await?;
        Ok(())
    }

    /// Removes this session from the client's registry.
    pub(crate) async fn unregister(&self) {
        let process_id = &self.process_id;
        self.client.unregister_session(process_id).await;
    }
}
impl Inner {
    /// Looks up the session registered for `process_id`.
    fn get_session(&self, process_id: &str) -> Option<Arc<SessionState>> {
        let registry = self.sessions.load();
        registry.get(process_id).cloned()
    }

    /// Registers `session` under `process_id`.
    ///
    /// Writers are serialized through `sessions_write_lock`; the registry is
    /// replaced with an updated copy rather than mutated in place.
    ///
    /// # Errors
    ///
    /// Returns `ExecServerError::Protocol` when `process_id` is already
    /// registered.
    async fn insert_session(
        &self,
        process_id: &str,
        session: Arc<SessionState>,
    ) -> Result<(), ExecServerError> {
        let _write_guard = self.sessions_write_lock.lock().await;
        let current = self.sessions.load();
        if current.contains_key(process_id) {
            return Err(ExecServerError::Protocol(format!(
                "session already registered for process {process_id}"
            )));
        }

        let mut updated = current.as_ref().clone();
        updated.insert(process_id.to_string(), session);
        self.sessions.store(Arc::new(updated));
        Ok(())
    }

    /// Removes and returns the session registered for `process_id`; returns
    /// `None` without touching the registry when there is no such session.
    async fn remove_session(&self, process_id: &str) -> Option<Arc<SessionState>> {
        let _write_guard = self.sessions_write_lock.lock().await;
        let current = self.sessions.load();
        let removed = current.get(process_id).cloned()?;

        let mut updated = current.as_ref().clone();
        updated.remove(process_id);
        self.sessions.store(Arc::new(updated));
        Some(removed)
    }

    /// Empties the registry and returns everything that was registered.
    async fn take_all_sessions(&self) -> HashMap<String, Arc<SessionState>> {
        let _write_guard = self.sessions_write_lock.lock().await;
        let current = self.sessions.load();
        let drained = current.as_ref().clone();
        self.sessions.store(Arc::new(HashMap::new()));
        drained
    }
}
/// Builds the failure message for a dropped exec-server transport, appending
/// `reason` after a colon when one is given.
fn disconnected_message(reason: Option<&str>) -> String {
    let Some(reason) = reason else {
        return "exec-server transport disconnected".to_string();
    };
    format!("exec-server transport disconnected: {reason}")
}
/// Reports whether `error` means the transport is gone: either the `Closed`
/// variant, or a server error with code `-32000` whose message is exactly
/// "JSON-RPC transport closed".
fn is_transport_closed_error(error: &ExecServerError) -> bool {
    match error {
        ExecServerError::Closed => true,
        ExecServerError::Server {
            code: -32000,
            message,
        } => message == "JSON-RPC transport closed",
        _ => false,
    }
}
/// Drains the session registry and records `message` as the failure on every
/// session that was registered.
async fn fail_all_sessions(inner: &Arc<Inner>, message: String) {
    let drained = inner.take_all_sessions().await;
    for session in drained.into_values() {
        session.set_failure(message.clone()).await;
    }
}
async fn handle_server_notification(
inner: &Arc<Inner>,
notification: JSONRPCNotification,
@@ -378,12 +595,26 @@ async fn handle_server_notification(
EXEC_OUTPUT_DELTA_METHOD => {
let params: ExecOutputDeltaNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
let _ = inner.events_tx.send(ExecServerEvent::OutputDelta(params));
if let Some(session) = inner.get_session(&params.process_id) {
session.note_change(params.seq);
}
}
EXEC_EXITED_METHOD => {
let params: ExecExitedNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
let _ = inner.events_tx.send(ExecServerEvent::Exited(params));
if let Some(session) = inner.get_session(&params.process_id) {
session.note_change(params.seq);
}
}
EXEC_CLOSED_METHOD => {
let params: ExecClosedNotification =
serde_json::from_value(notification.params.unwrap_or(Value::Null))?;
// Closed is the terminal lifecycle event for this process, so drop
// the routing entry before forwarding it.
let session = inner.remove_session(&params.process_id).await;
if let Some(session) = session {
session.note_change(params.seq);
}
}
other => {
debug!("ignoring unknown exec-server notification: {other}");
@@ -391,3 +622,153 @@ async fn handle_server_notification(
}
Ok(())
}
#[cfg(test)]
mod tests {
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use pretty_assertions::assert_eq;
use tokio::io::AsyncBufReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use tokio::io::BufReader;
use tokio::io::duplex;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio::time::timeout;
use super::ExecServerClient;
use super::ExecServerClientConnectOptions;
use crate::connection::JsonRpcConnection;
use crate::protocol::EXEC_EXITED_METHOD;
use crate::protocol::EXEC_OUTPUT_DELTA_METHOD;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
use crate::protocol::ExecOutputStream;
use crate::protocol::INITIALIZE_METHOD;
use crate::protocol::INITIALIZED_METHOD;
use crate::protocol::InitializeResponse;
/// Reads the next line from `lines` and parses it as a `JSONRPCMessage`.
///
/// Panics if no line arrives within one second, the read fails, the stream
/// has ended, or the line does not parse.
async fn read_jsonrpc_line<R>(lines: &mut tokio::io::Lines<BufReader<R>>) -> JSONRPCMessage
where
    R: tokio::io::AsyncRead + Unpin,
{
    let next_line = timeout(Duration::from_secs(1), lines.next_line());
    let raw = next_line
        .await
        .expect("json-rpc read should not time out")
        .expect("json-rpc read should succeed")
        .expect("json-rpc connection should stay open");
    serde_json::from_str(&raw).expect("json-rpc line should parse")
}
/// Serializes `message` as JSON and writes it to `writer` followed by a
/// newline. Panics if serialization or the write fails.
async fn write_jsonrpc_line<W>(writer: &mut W, message: JSONRPCMessage)
where
    W: AsyncWrite + Unpin,
{
    let mut line = serde_json::to_string(&message).expect("json-rpc message should serialize");
    line.push('\n');
    writer
        .write_all(line.as_bytes())
        .await
        .expect("json-rpc line should write");
}
// Checks that a burst of output notifications addressed to one session
// ("noisy") does not prevent a second session ("quiet") on the same client
// from observing its own wake-up within the timeout.
#[tokio::test]
async fn wake_notifications_do_not_block_other_sessions() {
// In-memory duplex pipes take the place of the stdio transport handed to
// `JsonRpcConnection::from_stdio` below.
let (client_stdin, server_reader) = duplex(1 << 20);
let (mut server_writer, client_stdout) = duplex(1 << 20);
// Channel the test body uses to hand notifications to the fake server task.
let (notifications_tx, mut notifications_rx) = mpsc::channel(16);
// Fake server: completes the initialize / initialized handshake, then writes
// every queued message to the client until the sender side is dropped.
let server = tokio::spawn(async move {
let mut lines = BufReader::new(server_reader).lines();
let initialize = read_jsonrpc_line(&mut lines).await;
let request = match initialize {
JSONRPCMessage::Request(request) if request.method == INITIALIZE_METHOD => request,
other => panic!("expected initialize request, got {other:?}"),
};
write_jsonrpc_line(
&mut server_writer,
JSONRPCMessage::Response(JSONRPCResponse {
id: request.id,
result: serde_json::to_value(InitializeResponse {})
.expect("initialize response should serialize"),
}),
)
.await;
let initialized = read_jsonrpc_line(&mut lines).await;
match initialized {
JSONRPCMessage::Notification(notification)
if notification.method == INITIALIZED_METHOD => {}
other => panic!("expected initialized notification, got {other:?}"),
}
while let Some(message) = notifications_rx.recv().await {
write_jsonrpc_line(&mut server_writer, message).await;
}
});
let client = ExecServerClient::connect(
JsonRpcConnection::from_stdio(
client_stdout,
client_stdin,
"test-exec-server-client".to_string(),
),
ExecServerClientConnectOptions::default(),
)
.await
.expect("client should connect");
// The noisy session is registered and kept alive, but nothing in this test
// reads from it.
let _noisy_session = client
.register_session("noisy")
.await
.expect("noisy session should register");
let quiet_session = client
.register_session("quiet")
.await
.expect("quiet session should register");
// Subscribe before any notification is queued.
let mut quiet_wake_rx = quiet_session.subscribe_wake();
// Queue 4097 stdout deltas (seq 0..=4096) for the noisy session.
for seq in 0..=4096 {
notifications_tx
.send(JSONRPCMessage::Notification(JSONRPCNotification {
method: EXEC_OUTPUT_DELTA_METHOD.to_string(),
params: Some(
serde_json::to_value(ExecOutputDeltaNotification {
process_id: "noisy".to_string(),
seq,
stream: ExecOutputStream::Stdout,
chunk: b"x".to_vec().into(),
})
.expect("output notification should serialize"),
),
}))
.await
.expect("output notification should queue");
}
// Then a single exit notification (seq 1) for the quiet session, queued
// behind all of the noisy traffic.
notifications_tx
.send(JSONRPCMessage::Notification(JSONRPCNotification {
method: EXEC_EXITED_METHOD.to_string(),
params: Some(
serde_json::to_value(ExecExitedNotification {
process_id: "quiet".to_string(),
seq: 1,
exit_code: 17,
})
.expect("exit notification should serialize"),
),
}))
.await
.expect("exit notification should queue");
// The quiet session must be woken within one second, and the wake value is
// expected to equal the seq of its exit notification.
timeout(Duration::from_secs(1), quiet_wake_rx.changed())
.await
.expect("quiet session should receive wake before timeout")
.expect("quiet wake channel should stay open");
assert_eq!(*quiet_wake_rx.borrow(), 1);
// Dropping the only sender ends the server's forwarding loop so the task
// can be joined.
drop(notifications_tx);
drop(client);
server.await.expect("server task should finish");
}
}
+24 -29
View File
@@ -8,14 +8,14 @@ use crate::RemoteExecServerConnectArgs;
use crate::file_system::ExecutorFileSystem;
use crate::local_file_system::LocalFileSystem;
use crate::local_process::LocalProcess;
use crate::process::ExecProcess;
use crate::process::ExecBackend;
use crate::remote_file_system::RemoteFileSystem;
use crate::remote_process::RemoteProcess;
pub const CODEX_EXEC_SERVER_URL_ENV_VAR: &str = "CODEX_EXEC_SERVER_URL";
pub trait ExecutorEnvironment: Send + Sync {
fn get_executor(&self) -> Arc<dyn ExecProcess>;
fn get_exec_backend(&self) -> Arc<dyn ExecBackend>;
}
#[derive(Debug, Default)]
@@ -56,7 +56,7 @@ impl EnvironmentManager {
pub struct Environment {
exec_server_url: Option<String>,
remote_exec_server_client: Option<ExecServerClient>,
executor: Arc<dyn ExecProcess>,
exec_backend: Arc<dyn ExecBackend>,
}
impl Default for Environment {
@@ -72,7 +72,7 @@ impl Default for Environment {
Self {
exec_server_url: None,
remote_exec_server_client: None,
executor: Arc::new(local_process),
exec_backend: Arc::new(local_process),
}
}
}
@@ -102,24 +102,24 @@ impl Environment {
None
};
let executor: Arc<dyn ExecProcess> = if let Some(client) = remote_exec_server_client.clone()
{
Arc::new(RemoteProcess::new(client))
} else {
let local_process = LocalProcess::default();
local_process
.initialize()
.map_err(|err| ExecServerError::Protocol(err.message))?;
local_process
.initialized()
.map_err(ExecServerError::Protocol)?;
Arc::new(local_process)
};
let exec_backend: Arc<dyn ExecBackend> =
if let Some(client) = remote_exec_server_client.clone() {
Arc::new(RemoteProcess::new(client))
} else {
let local_process = LocalProcess::default();
local_process
.initialize()
.map_err(|err| ExecServerError::Protocol(err.message))?;
local_process
.initialized()
.map_err(ExecServerError::Protocol)?;
Arc::new(local_process)
};
Ok(Self {
exec_server_url,
remote_exec_server_client,
executor,
exec_backend,
})
}
@@ -127,8 +127,8 @@ impl Environment {
self.exec_server_url.as_deref()
}
pub fn get_executor(&self) -> Arc<dyn ExecProcess> {
Arc::clone(&self.executor)
pub fn get_exec_backend(&self) -> Arc<dyn ExecBackend> {
Arc::clone(&self.exec_backend)
}
pub fn get_filesystem(&self) -> Arc<dyn ExecutorFileSystem> {
@@ -148,8 +148,8 @@ fn normalize_exec_server_url(exec_server_url: Option<String>) -> Option<String>
}
impl ExecutorEnvironment for Environment {
fn get_executor(&self) -> Arc<dyn ExecProcess> {
Arc::clone(&self.executor)
fn get_exec_backend(&self) -> Arc<dyn ExecBackend> {
Arc::clone(&self.exec_backend)
}
}
@@ -193,7 +193,7 @@ mod tests {
let environment = Environment::default();
let response = environment
.get_executor()
.get_exec_backend()
.start(crate::ExecParams {
process_id: "default-env-proc".to_string(),
argv: vec!["true".to_string()],
@@ -205,11 +205,6 @@ mod tests {
.await
.expect("start process");
assert_eq!(
response,
crate::ExecResponse {
process_id: "default-env-proc".to_string(),
}
);
assert_eq!(response.process.process_id().as_str(), "default-env-proc");
}
}
+5 -1
View File
@@ -41,8 +41,11 @@ pub use file_system::FileMetadata;
pub use file_system::FileSystemResult;
pub use file_system::ReadDirectoryEntry;
pub use file_system::RemoveOptions;
pub use process::ExecBackend;
pub use process::ExecProcess;
pub use process::ExecServerEvent;
pub use process::ProcessId;
pub use process::StartedExecProcess;
pub use protocol::ExecClosedNotification;
pub use protocol::ExecExitedNotification;
pub use protocol::ExecOutputDeltaNotification;
pub use protocol::ExecOutputStream;
@@ -56,6 +59,7 @@ pub use protocol::TerminateParams;
pub use protocol::TerminateResponse;
pub use protocol::WriteParams;
pub use protocol::WriteResponse;
pub use protocol::WriteStatus;
pub use server::DEFAULT_LISTEN_URL;
pub use server::ExecServerListenUrlParseError;
pub use server::run_main;
+192 -53
View File
@@ -11,13 +11,16 @@ use codex_utils_pty::ExecCommandSession;
use codex_utils_pty::TerminalSize;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tracing::warn;
use tokio::sync::watch;
use crate::ExecBackend;
use crate::ExecProcess;
use crate::ExecServerError;
use crate::ExecServerEvent;
use crate::ProcessId;
use crate::StartedExecProcess;
use crate::protocol::EXEC_CLOSED_METHOD;
use crate::protocol::ExecClosedNotification;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
use crate::protocol::ExecOutputStream;
@@ -31,6 +34,7 @@ use crate::protocol::TerminateParams;
use crate::protocol::TerminateResponse;
use crate::protocol::WriteParams;
use crate::protocol::WriteResponse;
use crate::protocol::WriteStatus;
use crate::rpc::RpcNotificationSender;
use crate::rpc::RpcServerOutboundMessage;
use crate::rpc::internal_error;
@@ -38,7 +42,6 @@ use crate::rpc::invalid_params;
use crate::rpc::invalid_request;
const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024;
const EVENT_CHANNEL_CAPACITY: usize = 256;
const NOTIFICATION_CHANNEL_CAPACITY: usize = 256;
#[cfg(test)]
const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25);
@@ -59,7 +62,10 @@ struct RunningProcess {
retained_bytes: usize,
next_seq: u64,
exit_code: Option<i32>,
wake_tx: watch::Sender<u64>,
output_notify: Arc<Notify>,
open_streams: usize,
closed: bool,
}
enum ProcessEntry {
@@ -69,7 +75,6 @@ enum ProcessEntry {
struct Inner {
notifications: RpcNotificationSender,
events_tx: broadcast::Sender<ExecServerEvent>,
processes: Mutex<HashMap<String, ProcessEntry>>,
initialize_requested: AtomicBool,
initialized: AtomicBool,
@@ -80,6 +85,12 @@ pub(crate) struct LocalProcess {
inner: Arc<Inner>,
}
// Per-process handle returned by the local backend's `ExecBackend::start`.
struct LocalExecProcess {
// Identifier of the process this handle operates on.
process_id: ProcessId,
// Clone of the `LocalProcess` that started the process; reads, writes and
// terminate are forwarded to it together with `process_id`.
backend: LocalProcess,
// Sender side of the process's wake channel; `subscribe_wake` hands out
// receivers created from it.
wake_tx: watch::Sender<u64>,
}
impl Default for LocalProcess {
fn default() -> Self {
let (outgoing_tx, mut outgoing_rx) =
@@ -94,7 +105,6 @@ impl LocalProcess {
Self {
inner: Arc::new(Inner {
notifications,
events_tx: broadcast::channel(EVENT_CHANNEL_CAPACITY).0,
processes: Mutex::new(HashMap::new()),
initialize_requested: AtomicBool::new(false),
initialized: AtomicBool::new(false),
@@ -152,10 +162,12 @@ impl LocalProcess {
Ok(())
}
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
async fn start_process(
&self,
params: ExecParams,
) -> Result<(ExecResponse, watch::Sender<u64>), JSONRPCErrorError> {
self.require_initialized_for("exec")?;
let process_id = params.process_id.clone();
let (program, args) = params
.argv
.split_first()
@@ -203,6 +215,7 @@ impl LocalProcess {
};
let output_notify = Arc::new(Notify::new());
let (wake_tx, _wake_rx) = watch::channel(0);
{
let mut process_map = self.inner.processes.lock().await;
process_map.insert(
@@ -214,7 +227,10 @@ impl LocalProcess {
retained_bytes: 0,
next_seq: 1,
exit_code: None,
wake_tx: wake_tx.clone(),
output_notify: Arc::clone(&output_notify),
open_streams: 2,
closed: false,
})),
);
}
@@ -248,7 +264,13 @@ impl LocalProcess {
output_notify,
));
Ok(ExecResponse { process_id })
Ok((ExecResponse { process_id }, wake_tx))
}
/// Starts a process and returns only the protocol-level `ExecResponse`.
///
/// The wake sender produced by `start_process` is dropped here.
pub(crate) async fn exec(&self, params: ExecParams) -> Result<ExecResponse, JSONRPCErrorError> {
    let (response, _wake_tx) = self.start_process(params).await?;
    Ok(response)
}
pub(crate) async fn exec_read(
@@ -256,6 +278,7 @@ impl LocalProcess {
params: ReadParams,
) -> Result<ReadResponse, JSONRPCErrorError> {
self.require_initialized_for("exec")?;
let _process_id = params.process_id.clone();
let after_seq = params.after_seq.unwrap_or(0);
let max_bytes = params.max_bytes.unwrap_or(usize::MAX);
let wait = Duration::from_millis(params.wait_ms.unwrap_or(0));
@@ -300,6 +323,8 @@ impl LocalProcess {
next_seq,
exited: process.exit_code.is_some(),
exit_code: process.exit_code,
closed: process.closed,
failure: None,
},
Arc::clone(&process.output_notify),
)
@@ -309,6 +334,11 @@ impl LocalProcess {
|| response.exited
|| tokio::time::Instant::now() >= deadline
{
let _total_bytes: usize = response
.chunks
.iter()
.map(|chunk| chunk.chunk.0.len())
.sum();
return Ok(response);
}
@@ -325,22 +355,24 @@ impl LocalProcess {
params: WriteParams,
) -> Result<WriteResponse, JSONRPCErrorError> {
self.require_initialized_for("exec")?;
let _process_id = params.process_id.clone();
let _input_bytes = params.chunk.0.len();
let writer_tx = {
let process_map = self.inner.processes.lock().await;
let process = process_map.get(&params.process_id).ok_or_else(|| {
invalid_request(format!("unknown process id {}", params.process_id))
})?;
let Some(process) = process_map.get(&params.process_id) else {
return Ok(WriteResponse {
status: WriteStatus::UnknownProcess,
});
};
let ProcessEntry::Running(process) = process else {
return Err(invalid_request(format!(
"process id {} is starting",
params.process_id
)));
return Ok(WriteResponse {
status: WriteStatus::Starting,
});
};
if !process.tty {
return Err(invalid_request(format!(
"stdin is closed for process {}",
params.process_id
)));
return Ok(WriteResponse {
status: WriteStatus::StdinClosed,
});
}
process.session.writer_sender()
};
@@ -350,7 +382,9 @@ impl LocalProcess {
.await
.map_err(|_| internal_error("failed to write to process stdin".to_string()))?;
Ok(WriteResponse { accepted: true })
Ok(WriteResponse {
status: WriteStatus::Accepted,
})
}
pub(crate) async fn terminate_process(
@@ -358,6 +392,7 @@ impl LocalProcess {
params: TerminateParams,
) -> Result<TerminateResponse, JSONRPCErrorError> {
self.require_initialized_for("exec")?;
let _process_id = params.process_id.clone();
let running = {
let process_map = self.inner.processes.lock().await;
match process_map.get(&params.process_id) {
@@ -377,13 +412,68 @@ impl LocalProcess {
}
#[async_trait]
impl ExecProcess for LocalProcess {
async fn start(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
self.exec(params).await.map_err(map_handler_error)
impl ExecBackend for LocalProcess {
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
let (response, wake_tx) = self
.start_process(params)
.await
.map_err(map_handler_error)?;
Ok(StartedExecProcess {
process: Arc::new(LocalExecProcess {
process_id: response.process_id.into(),
backend: self.clone(),
wake_tx,
}),
})
}
}
#[async_trait]
impl ExecProcess for LocalExecProcess {
fn process_id(&self) -> &ProcessId {
&self.process_id
}
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
self.exec_read(params).await.map_err(map_handler_error)
fn subscribe_wake(&self) -> watch::Receiver<u64> {
self.wake_tx.subscribe()
}
async fn read(
&self,
after_seq: Option<u64>,
max_bytes: Option<usize>,
wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError> {
self.backend
.read(&self.process_id, after_seq, max_bytes, wait_ms)
.await
}
async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
self.backend.write(&self.process_id, chunk).await
}
async fn terminate(&self) -> Result<(), ExecServerError> {
self.backend.terminate(&self.process_id).await
}
}
impl LocalProcess {
async fn read(
&self,
process_id: &str,
after_seq: Option<u64>,
max_bytes: Option<usize>,
wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError> {
self.exec_read(ReadParams {
process_id: process_id.to_string(),
after_seq,
max_bytes,
wait_ms,
})
.await
.map_err(map_handler_error)
}
async fn write(
@@ -399,16 +489,13 @@ impl ExecProcess for LocalProcess {
.map_err(map_handler_error)
}
async fn terminate(&self, process_id: &str) -> Result<TerminateResponse, ExecServerError> {
async fn terminate(&self, process_id: &str) -> Result<(), ExecServerError> {
self.terminate_process(TerminateParams {
process_id: process_id.to_string(),
})
.await
.map_err(map_handler_error)
}
fn subscribe_events(&self) -> broadcast::Receiver<ExecServerEvent> {
self.inner.events_tx.subscribe()
.map_err(map_handler_error)?;
Ok(())
}
}
@@ -427,6 +514,7 @@ async fn stream_output(
output_notify: Arc<Notify>,
) {
while let Some(chunk) = receiver.recv().await {
let _chunk_len = chunk.len();
let notification = {
let mut processes = inner.processes.lock().await;
let Some(entry) = processes.get_mut(&process_id) else {
@@ -448,21 +536,16 @@ async fn stream_output(
break;
};
process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len());
warn!(
"retained output cap exceeded for process {process_id}; dropping oldest output"
);
}
let _ = process.wake_tx.send(seq);
ExecOutputDeltaNotification {
process_id: process_id.clone(),
seq,
stream,
chunk: chunk.into(),
}
};
output_notify.notify_waiters();
let _ = inner
.events_tx
.send(ExecServerEvent::OutputDelta(notification.clone()));
if inner
.notifications
.notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, &notification)
@@ -472,6 +555,8 @@ async fn stream_output(
break;
}
}
finish_output_stream(process_id, inner).await;
}
async fn watch_exit(
@@ -481,29 +566,35 @@ async fn watch_exit(
output_notify: Arc<Notify>,
) {
let exit_code = exit_rx.await.unwrap_or(-1);
{
let notification = {
let mut processes = inner.processes.lock().await;
if let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) {
let seq = process.next_seq;
process.next_seq += 1;
process.exit_code = Some(exit_code);
let _ = process.wake_tx.send(seq);
Some(ExecExitedNotification {
process_id: process_id.clone(),
seq,
exit_code,
})
} else {
None
}
}
output_notify.notify_waiters();
let notification = ExecExitedNotification {
process_id: process_id.clone(),
exit_code,
};
let _ = inner
.events_tx
.send(ExecServerEvent::Exited(notification.clone()));
if inner
.notifications
.notify(crate::protocol::EXEC_EXITED_METHOD, &notification)
.await
.is_err()
output_notify.notify_waiters();
if let Some(notification) = notification
&& inner
.notifications
.notify(crate::protocol::EXEC_EXITED_METHOD, &notification)
.await
.is_err()
{
return;
}
maybe_emit_closed(process_id.clone(), Arc::clone(&inner)).await;
tokio::time::sleep(EXITED_PROCESS_RETENTION).await;
let mut processes = inner.processes.lock().await;
if matches!(
@@ -513,3 +604,51 @@ async fn watch_exit(
processes.remove(&process_id);
}
}
/// Records that one of the process's output streams has finished, then checks
/// whether the closed notification can be emitted.
///
/// Does nothing if the process is unknown or not in the `Running` state.
async fn finish_output_stream(process_id: String, inner: Arc<Inner>) {
    {
        let mut table = inner.processes.lock().await;
        match table.get_mut(&process_id) {
            Some(ProcessEntry::Running(running)) => {
                // Never goes below zero, even if called more often than
                // there are streams.
                running.open_streams = running.open_streams.saturating_sub(1);
            }
            _ => return,
        }
        // The lock is released at the end of this scope, before
        // `maybe_emit_closed` acquires it again.
    }
    maybe_emit_closed(process_id, inner).await;
}
/// Emits the `process/closed` notification once a process is fully finished.
///
/// A process counts as closed when it has an exit code and no output stream
/// is still open. The notification is produced at most once per process: the
/// `closed` flag is set under the process-table lock before anything is sent.
/// The closed event takes the next sequence number and wakes subscribers of
/// the process's wake channel.
///
/// Does nothing if the process is unknown, not `Running`, already closed,
/// still has open streams, or has not exited yet.
async fn maybe_emit_closed(process_id: String, inner: Arc<Inner>) {
    let notification = {
        let mut processes = inner.processes.lock().await;
        let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) else {
            return;
        };
        if process.closed || process.open_streams != 0 || process.exit_code.is_none() {
            return;
        }
        process.closed = true;
        let seq = process.next_seq;
        process.next_seq += 1;
        // A send error only means nobody is subscribed; that is fine.
        let _ = process.wake_tx.send(seq);
        ExecClosedNotification { process_id, seq }
    };

    // Best-effort: the lock is already released, and a failed notify leaves
    // nothing further to do for this process.
    let _ = inner
        .notifications
        .notify(EXEC_CLOSED_METHOD, &notification)
        .await;
}
+60 -18
View File
@@ -1,35 +1,77 @@
use std::fmt;
use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::broadcast;
use tokio::sync::watch;
use crate::ExecServerError;
use crate::protocol::ExecExitedNotification;
use crate::protocol::ExecOutputDeltaNotification;
use crate::protocol::ExecParams;
use crate::protocol::ExecResponse;
use crate::protocol::ReadParams;
use crate::protocol::ReadResponse;
use crate::protocol::TerminateResponse;
use crate::protocol::WriteResponse;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecServerEvent {
OutputDelta(ExecOutputDeltaNotification),
Exited(ExecExitedNotification),
/// Newtype wrapper around the `String` identifier of an exec process.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ProcessId(String);
/// Value returned by `ExecBackend::start` for a successfully started process.
pub struct StartedExecProcess {
/// Handle used to read from, write to, and terminate the started process.
pub process: Arc<dyn ExecProcess>,
}
impl ProcessId {
pub fn as_str(&self) -> &str {
&self.0
}
pub fn into_inner(self) -> String {
self.0
}
}
/// Lets a `ProcessId` be used wherever a `&str` is expected.
impl Deref for ProcessId {
    type Target = str;

    fn deref(&self) -> &str {
        &self.0
    }
}
/// Cheap borrowed conversion to the underlying identifier text.
impl AsRef<str> for ProcessId {
    fn as_ref(&self) -> &str {
        &self.0
    }
}
/// Formats the id exactly like the wrapped string, honouring any width or
/// alignment flags given by the caller.
impl fmt::Display for ProcessId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl From<String> for ProcessId {
fn from(value: String) -> Self {
Self(value)
}
}
#[async_trait]
pub trait ExecProcess: Send + Sync {
async fn start(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError>;
fn process_id(&self) -> &ProcessId;
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError>;
fn subscribe_wake(&self) -> watch::Receiver<u64>;
async fn write(
async fn read(
&self,
process_id: &str,
chunk: Vec<u8>,
) -> Result<WriteResponse, ExecServerError>;
after_seq: Option<u64>,
max_bytes: Option<usize>,
wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError>;
async fn terminate(&self, process_id: &str) -> Result<TerminateResponse, ExecServerError>;
async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError>;
fn subscribe_events(&self) -> broadcast::Receiver<ExecServerEvent>;
async fn terminate(&self) -> Result<(), ExecServerError>;
}
/// Something that can start exec processes.
#[async_trait]
pub trait ExecBackend: Send + Sync {
/// Starts a process described by `params` and returns a handle to it, or an
/// `ExecServerError` if the launch fails.
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError>;
}
+22 -1
View File
@@ -13,6 +13,7 @@ pub const EXEC_WRITE_METHOD: &str = "process/write";
pub const EXEC_TERMINATE_METHOD: &str = "process/terminate";
pub const EXEC_OUTPUT_DELTA_METHOD: &str = "process/output";
pub const EXEC_EXITED_METHOD: &str = "process/exited";
pub const EXEC_CLOSED_METHOD: &str = "process/closed";
pub const FS_READ_FILE_METHOD: &str = "fs/readFile";
pub const FS_WRITE_FILE_METHOD: &str = "fs/writeFile";
pub const FS_CREATE_DIRECTORY_METHOD: &str = "fs/createDirectory";
@@ -90,6 +91,8 @@ pub struct ReadResponse {
pub next_seq: u64,
pub exited: bool,
pub exit_code: Option<i32>,
pub closed: bool,
pub failure: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -99,10 +102,19 @@ pub struct WriteParams {
pub chunk: ByteChunk,
}
/// Outcome of a stdin write, carried in `WriteResponse::status`.
///
/// Variant names are serialized in camelCase. The per-variant notes describe
/// how the local backend's write handler uses them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum WriteStatus {
/// The chunk was handed to the process's stdin writer.
Accepted,
/// No process is registered under the requested id.
UnknownProcess,
/// The process does not accept stdin; the local backend returns this when
/// the process was not started with a tty.
StdinClosed,
/// The process entry exists but is not in the running state yet.
Starting,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WriteResponse {
pub accepted: bool,
pub status: WriteStatus,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -129,6 +141,7 @@ pub enum ExecOutputStream {
#[serde(rename_all = "camelCase")]
pub struct ExecOutputDeltaNotification {
pub process_id: String,
pub seq: u64,
pub stream: ExecOutputStream,
pub chunk: ByteChunk,
}
@@ -137,9 +150,17 @@ pub struct ExecOutputDeltaNotification {
#[serde(rename_all = "camelCase")]
pub struct ExecExitedNotification {
pub process_id: String,
pub seq: u64,
pub exit_code: i32,
}
/// Payload of the `process/closed` notification.
///
/// The local backend sends it once a process has an exit code and none of its
/// output streams is still open.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExecClosedNotification {
/// Identifier of the process that closed.
pub process_id: String,
/// Sequence number of this event, taken from the same per-process counter
/// the local backend uses for output and exit events.
pub seq: u64,
}
mod base64_bytes {
use super::BASE64_STANDARD;
use base64::Engine as _;
+61 -33
View File
@@ -1,16 +1,17 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::broadcast;
use tokio::sync::watch;
use tracing::trace;
use crate::ExecBackend;
use crate::ExecProcess;
use crate::ExecServerClient;
use crate::ExecServerError;
use crate::ExecServerEvent;
use crate::StartedExecProcess;
use crate::client::ExecServerClient;
use crate::client::Session;
use crate::protocol::ExecParams;
use crate::protocol::ExecResponse;
use crate::protocol::ReadParams;
use crate::protocol::ReadResponse;
use crate::protocol::TerminateResponse;
use crate::protocol::WriteResponse;
#[derive(Clone)]
@@ -18,6 +19,10 @@ pub(crate) struct RemoteProcess {
client: ExecServerClient,
}
struct RemoteExecProcess {
session: Session,
}
impl RemoteProcess {
pub(crate) fn new(client: ExecServerClient) -> Self {
trace!("remote process new");
@@ -26,33 +31,56 @@ impl RemoteProcess {
}
#[async_trait]
impl ExecProcess for RemoteProcess {
async fn start(&self, params: ExecParams) -> Result<ExecResponse, ExecServerError> {
trace!("remote process start");
self.client.exec(params).await
}
impl ExecBackend for RemoteProcess {
async fn start(&self, params: ExecParams) -> Result<StartedExecProcess, ExecServerError> {
let process_id = params.process_id.clone();
let session = self.client.register_session(&process_id).await?;
if let Err(err) = self.client.exec(params).await {
session.unregister().await;
return Err(err);
}
async fn read(&self, params: ReadParams) -> Result<ReadResponse, ExecServerError> {
trace!("remote process read");
self.client.read(params).await
}
async fn write(
&self,
process_id: &str,
chunk: Vec<u8>,
) -> Result<WriteResponse, ExecServerError> {
trace!("remote process write");
self.client.write(process_id, chunk).await
}
async fn terminate(&self, process_id: &str) -> Result<TerminateResponse, ExecServerError> {
trace!("remote process terminate");
self.client.terminate(process_id).await
}
fn subscribe_events(&self) -> broadcast::Receiver<ExecServerEvent> {
trace!("remote process subscribe_events");
self.client.event_receiver()
Ok(StartedExecProcess {
process: Arc::new(RemoteExecProcess { session }),
})
}
}
// Handle for a process running behind a remote exec-server. Every operation
// is forwarded to the client `Session` registered for the process.
#[async_trait]
impl ExecProcess for RemoteExecProcess {
// Identifier of the process, as stored by the session.
fn process_id(&self) -> &crate::ProcessId {
self.session.process_id()
}
// Returns a new receiver on the session's wake channel.
fn subscribe_wake(&self) -> watch::Receiver<u64> {
self.session.subscribe_wake()
}
// Forwards the read arguments unchanged to the session.
async fn read(
&self,
after_seq: Option<u64>,
max_bytes: Option<usize>,
wait_ms: Option<u64>,
) -> Result<ReadResponse, ExecServerError> {
self.session.read(after_seq, max_bytes, wait_ms).await
}
// Forwards `chunk` to the session; the write outcome is reported through
// the returned `WriteResponse`.
async fn write(&self, chunk: Vec<u8>) -> Result<WriteResponse, ExecServerError> {
trace!("exec process write");
self.session.write(chunk).await
}
// Asks the session to terminate the process.
async fn terminate(&self) -> Result<(), ExecServerError> {
trace!("exec process terminate");
self.session.terminate().await
}
}
/// Unregisters the session when the process handle is dropped.
///
/// `Session::unregister` is async, so it runs on a detached task.
impl Drop for RemoteExecProcess {
    fn drop(&mut self) {
        // `tokio::spawn` panics when no runtime is active on this thread,
        // and a panic inside `drop` during unwinding aborts the process.
        // Dropping the handle outside a runtime (or while it shuts down)
        // therefore skips the best-effort cleanup instead of panicking.
        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            return;
        };
        let session = self.session.clone();
        runtime.spawn(async move {
            session.unregister().await;
        });
    }
}
+2 -3
View File
@@ -19,7 +19,6 @@ use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tracing::warn;
use crate::connection::JsonRpcConnection;
use crate::connection::JsonRpcConnectionEvent;
@@ -192,12 +191,12 @@ impl RpcClient {
if let Err(err) =
handle_server_message(&pending_for_reader, &event_tx, message).await
{
warn!("JSON-RPC client closing after protocol error: {err}");
let _ = err;
break;
}
}
JsonRpcConnectionEvent::MalformedMessage { reason } => {
warn!("JSON-RPC client closing after malformed message: {reason}");
let _ = reason;
break;
}
JsonRpcConnectionEvent::Disconnected { reason } => {
+221 -29
View File
@@ -6,19 +6,23 @@ use std::sync::Arc;
use anyhow::Result;
use codex_exec_server::Environment;
use codex_exec_server::ExecBackend;
use codex_exec_server::ExecParams;
use codex_exec_server::ExecProcess;
use codex_exec_server::ExecResponse;
use codex_exec_server::ReadParams;
use codex_exec_server::ReadResponse;
use codex_exec_server::StartedExecProcess;
use pretty_assertions::assert_eq;
use test_case::test_case;
use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::timeout;
use common::exec_server::ExecServerHarness;
use common::exec_server::exec_server;
struct ProcessContext {
process: Arc<dyn ExecProcess>,
_server: Option<ExecServerHarness>,
backend: Arc<dyn ExecBackend>,
server: Option<ExecServerHarness>,
}
async fn create_process_context(use_remote: bool) -> Result<ProcessContext> {
@@ -26,22 +30,22 @@ async fn create_process_context(use_remote: bool) -> Result<ProcessContext> {
let server = exec_server().await?;
let environment = Environment::create(Some(server.websocket_url().to_string())).await?;
Ok(ProcessContext {
process: environment.get_executor(),
_server: Some(server),
backend: environment.get_exec_backend(),
server: Some(server),
})
} else {
let environment = Environment::create(/*exec_server_url*/ None).await?;
Ok(ProcessContext {
process: environment.get_executor(),
_server: None,
backend: environment.get_exec_backend(),
server: None,
})
}
}
async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
let context = create_process_context(use_remote).await?;
let response = context
.process
let session = context
.backend
.start(ExecParams {
process_id: "proc-1".to_string(),
argv: vec!["true".to_string()],
@@ -51,30 +55,197 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
arg0: None,
})
.await?;
assert_eq!(
response,
ExecResponse {
process_id: "proc-1".to_string(),
}
);
assert_eq!(session.process.process_id().as_str(), "proc-1");
let wake_rx = session.process.subscribe_wake();
let (_, exit_code, closed) =
collect_process_output_from_reads(session.process, wake_rx).await?;
let mut next_seq = 0;
assert_eq!(exit_code, Some(0));
assert!(closed);
Ok(())
}
/// Performs a non-blocking read and, if it reports nothing new, waits for one
/// wake-up and reads again.
///
/// "Nothing new" means the first read returned no chunks, is not closed, and
/// carries no failure. The wait for `wake_rx` is bounded to two seconds; a
/// timeout or a closed wake channel is returned as an error.
async fn read_process_until_change(
    session: Arc<dyn ExecProcess>,
    wake_rx: &mut watch::Receiver<u64>,
    after_seq: Option<u64>,
) -> Result<ReadResponse> {
    let first = session
        .read(after_seq, /*max_bytes*/ None, /*wait_ms*/ Some(0))
        .await?;
    let nothing_new = first.chunks.is_empty() && !first.closed && first.failure.is_none();
    if nothing_new {
        timeout(Duration::from_secs(2), wake_rx.changed()).await??;
        let second = session
            .read(after_seq, /*max_bytes*/ None, /*wait_ms*/ Some(0))
            .await?;
        Ok(second)
    } else {
        Ok(first)
    }
}
async fn collect_process_output_from_reads(
session: Arc<dyn ExecProcess>,
mut wake_rx: watch::Receiver<u64>,
) -> Result<(String, Option<i32>, bool)> {
let mut output = String::new();
let mut exit_code = None;
let mut after_seq = None;
loop {
let read = context
.process
.read(ReadParams {
process_id: "proc-1".to_string(),
after_seq: Some(next_seq),
max_bytes: None,
wait_ms: Some(100),
})
.await?;
next_seq = read.next_seq;
if read.exited {
assert_eq!(read.exit_code, Some(0));
let response =
read_process_until_change(Arc::clone(&session), &mut wake_rx, after_seq).await?;
if let Some(message) = response.failure {
anyhow::bail!("process failed before closed state: {message}");
}
for chunk in response.chunks {
output.push_str(&String::from_utf8_lossy(&chunk.chunk.into_inner()));
after_seq = Some(chunk.seq);
}
if response.exited {
exit_code = response.exit_code;
}
if response.closed {
break;
}
after_seq = response.next_seq.checked_sub(1).or(after_seq);
}
drop(session);
Ok((output, exit_code, true))
}
/// Starts a short shell script that prints one line after a brief sleep and
/// checks that the line, exit code 0 and the closed state are all observed
/// through the process handle.
async fn assert_exec_process_streams_output(use_remote: bool) -> Result<()> {
    let context = create_process_context(use_remote).await?;
    let process_id = "proc-stream".to_string();

    let params = ExecParams {
        process_id: process_id.clone(),
        argv: vec![
            "/bin/sh".to_string(),
            "-c".to_string(),
            "sleep 0.05; printf 'session output\\n'".to_string(),
        ],
        cwd: std::env::current_dir()?,
        env: Default::default(),
        tty: false,
        arg0: None,
    };
    let StartedExecProcess { process } = context.backend.start(params).await?;
    assert_eq!(process.process_id().as_str(), process_id);

    let wake_rx = process.subscribe_wake();
    let collected = collect_process_output_from_reads(process, wake_rx).await?;
    let (output, exit_code, closed) = collected;

    assert_eq!(output, "session output\n");
    assert_eq!(exit_code, Some(0));
    assert!(closed);
    Ok(())
}
async fn assert_exec_process_write_then_read(use_remote: bool) -> Result<()> {
let context = create_process_context(use_remote).await?;
let process_id = "proc-stdin".to_string();
let session = context
.backend
.start(ExecParams {
process_id: process_id.clone(),
argv: vec![
"/usr/bin/python3".to_string(),
"-c".to_string(),
"import sys; line = sys.stdin.readline(); sys.stdout.write(f'from-stdin:{line}'); sys.stdout.flush()".to_string(),
],
cwd: std::env::current_dir()?,
env: Default::default(),
tty: true,
arg0: None,
})
.await?;
assert_eq!(session.process.process_id().as_str(), process_id);
tokio::time::sleep(Duration::from_millis(200)).await;
session.process.write(b"hello\n".to_vec()).await?;
let StartedExecProcess { process } = session;
let wake_rx = process.subscribe_wake();
let (output, exit_code, closed) = collect_process_output_from_reads(process, wake_rx).await?;
assert!(
output.contains("from-stdin:hello"),
"unexpected output: {output:?}"
);
assert_eq!(exit_code, Some(0));
assert!(closed);
Ok(())
}
/// Starts a process that prints immediately, waits 200 ms before subscribing
/// to wake-ups, and checks that the output, exit code 0 and the closed state
/// are still observed afterwards.
async fn assert_exec_process_preserves_queued_events_before_subscribe(
    use_remote: bool,
) -> Result<()> {
    let context = create_process_context(use_remote).await?;

    let params = ExecParams {
        process_id: "proc-queued".to_string(),
        argv: vec![
            "/bin/sh".to_string(),
            "-c".to_string(),
            "printf 'queued output\\n'".to_string(),
        ],
        cwd: std::env::current_dir()?,
        env: Default::default(),
        tty: false,
        arg0: None,
    };
    let StartedExecProcess { process } = context.backend.start(params).await?;

    // Let the process run before anyone subscribes or reads.
    let settle = Duration::from_millis(200);
    tokio::time::sleep(settle).await;

    let wake_rx = process.subscribe_wake();
    let collected = collect_process_output_from_reads(process, wake_rx).await?;
    let (output, exit_code, closed) = collected;

    assert_eq!(output, "queued output\n");
    assert_eq!(exit_code, Some(0));
    assert!(closed);
    Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_exec_process_reports_transport_disconnect() -> Result<()> {
let mut context = create_process_context(/*use_remote*/ true).await?;
let session = context
.backend
.start(ExecParams {
process_id: "proc-disconnect".to_string(),
argv: vec![
"/bin/sh".to_string(),
"-c".to_string(),
"sleep 10".to_string(),
],
cwd: std::env::current_dir()?,
env: Default::default(),
tty: false,
arg0: None,
})
.await?;
let server = context
.server
.as_mut()
.expect("remote context should include exec-server harness");
server.shutdown().await?;
let mut wake_rx = session.process.subscribe_wake();
let response = read_process_until_change(session.process, &mut wake_rx, None).await?;
let message = response
.failure
.expect("disconnect should surface as a failure");
assert!(
message.starts_with("exec-server transport disconnected"),
"unexpected failure message: {message}"
);
assert!(
response.closed,
"disconnect should close the process session"
);
Ok(())
}
@@ -85,3 +256,24 @@ async fn assert_exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
async fn exec_process_starts_and_exits(use_remote: bool) -> Result<()> {
assert_exec_process_starts_and_exits(use_remote).await
}
// Runs the streaming-output scenario once against the local backend
// (`use_remote = false`) and once against a remote exec-server (`true`).
#[test_case(false ; "local")]
#[test_case(true ; "remote")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_process_streams_output(use_remote: bool) -> Result<()> {
assert_exec_process_streams_output(use_remote).await
}
// Runs the stdin write-then-read scenario once against the local backend
// (`use_remote = false`) and once against a remote exec-server (`true`).
#[test_case(false ; "local")]
#[test_case(true ; "remote")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_process_write_then_read(use_remote: bool) -> Result<()> {
assert_exec_process_write_then_read(use_remote).await
}
// Runs the late-subscribe scenario once against the local backend
// (`use_remote = false`) and once against a remote exec-server (`true`).
#[test_case(false ; "local")]
#[test_case(true ; "remote")]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn exec_process_preserves_queued_events_before_subscribe(use_remote: bool) -> Result<()> {
assert_exec_process_preserves_queued_events_before_subscribe(use_remote).await
}