diff --git a/codex-rs/shell-command/src/command_safety/mod.rs b/codex-rs/shell-command/src/command_safety/mod.rs index caf5c9f6e..12e467bdb 100644 --- a/codex-rs/shell-command/src/command_safety/mod.rs +++ b/codex-rs/shell-command/src/command_safety/mod.rs @@ -1,3 +1,5 @@ +mod powershell_parser; + pub mod is_dangerous_command; pub mod is_safe_command; pub mod windows_safe_commands; diff --git a/codex-rs/shell-command/src/command_safety/powershell_parser.ps1 b/codex-rs/shell-command/src/command_safety/powershell_parser.ps1 index af71cb7f3..9f19f172e 100644 --- a/codex-rs/shell-command/src/command_safety/powershell_parser.ps1 +++ b/codex-rs/shell-command/src/command_safety/powershell_parser.ps1 @@ -1,44 +1,98 @@ $ErrorActionPreference = 'Stop' +$ProgressPreference = 'SilentlyContinue' -$payload = $env:CODEX_POWERSHELL_PAYLOAD -if ([string]::IsNullOrEmpty($payload)) { - Write-Output '{"status":"parse_failed"}' - exit 0 -} +# Long-lived PowerShell AST parser used by the Rust command-safety layer on Windows. +# The caller starts one child process per PowerShell executable variant and then sends +# newline-delimited JSON requests over stdin: +# { "id": , "payload": "" } +# We answer with one compact JSON line per request: +# { "id": , "status": "ok", "commands": [["Get-Content", "foo.txt"]] } +# or: +# { "id": , "status": "parse_failed" | "parse_errors" | "unsupported" } +# +# "unsupported" is intentional: it means the script parsed successfully, but the AST +# included constructs that we conservatively refuse to lower into argv-like command words. +# The Rust side treats that the same way as an unsafe command. -try { - $source = - [System.Text.Encoding]::Unicode.GetString( - [System.Convert]::FromBase64String($payload) +# Use BOM-free UTF-8 on the protocol stream so Rust sees clean JSON lines with no +# leading BOM bytes on the first response. +$utf8 = [System.Text.UTF8Encoding]::new($false) +$stdin = [System.IO.StreamReader]::new([Console]::OpenStandardInput(), $utf8, $false) +$stdout = [System.IO.StreamWriter]::new([Console]::OpenStandardOutput(), $utf8) +$stdout.AutoFlush = $true + +function Invoke-ParseRequest { + param($RequestId, $Source) + + $tokens = $null + $errors = $null + + $ast = $null + try { + $ast = [System.Management.Automation.Language.Parser]::ParseInput( + $Source, + [ref]$tokens, + [ref]$errors ) -} catch { - Write-Output '{"status":"parse_failed"}' - exit 0 + } catch { + return @{ id = $RequestId; status = 'parse_failed' } + } + + if ($errors.Count -gt 0) { + return @{ id = $RequestId; status = 'parse_errors' } + } + + # Only accept AST shapes we can flatten into a list of argv-like command words. + # Anything more dynamic than that becomes "unsupported" instead of being guessed at. + $commands = [System.Collections.ArrayList]::new() + + foreach ($statement in $ast.EndBlock.Statements) { + if (-not (Add-CommandsFromPipelineBase $statement $commands)) { + $commands = $null + break + } + } + + if ($commands -ne $null) { + $normalized = [System.Collections.ArrayList]::new() + foreach ($cmd in $commands) { + # Convert every successful parse result to an array-of-arrays shape so the Rust + # side can deserialize one uniform representation. + if ($cmd -is [string]) { + $null = $normalized.Add(@($cmd)) + continue + } + + if ($cmd -is [System.Array] -or $cmd -is [System.Collections.IEnumerable]) { + $null = $normalized.Add(@($cmd)) + continue + } + + $normalized = $null + break + } + + $commands = $normalized + } + + if ($commands -eq $null) { + return @{ id = $RequestId; status = 'unsupported' } + } + + return @{ id = $RequestId; status = 'ok'; commands = $commands } } -$tokens = $null -$errors = $null +function Write-Response { + param($Response) -$ast = $null -try { - $ast = [System.Management.Automation.Language.Parser]::ParseInput( - $source, - [ref]$tokens, - [ref]$errors - ) -} catch { - Write-Output '{"status":"parse_failed"}' - exit 0 -} - -if ($errors.Count -gt 0) { - Write-Output '{"status":"parse_errors"}' - exit 0 + $stdout.WriteLine(($Response | ConvertTo-Json -Compress -Depth 3)) } function Convert-CommandElement { param($element) + # Accept only literal-ish command elements. Variable expansion, subexpressions, splats, + # and other dynamic forms return $null so the whole request becomes unsupported. if ($element -is [System.Management.Automation.Language.StringConstantExpressionAst]) { return @($element.Value) } @@ -77,6 +131,8 @@ function Convert-PipelineElement { param($element) if ($element -is [System.Management.Automation.Language.CommandAst]) { + # Redirections and invocation operators make the command harder to classify safely, + # so reject them rather than trying to normalize them. if ($element.Redirections.Count -gt 0) { return $null } @@ -104,6 +160,8 @@ function Convert-PipelineElement { return $null } + # Allow a parenthesized single pipeline element like "(Get-Content foo.rs -Raw)" so + # the caller still sees the inner command words. More complex expressions stay unsupported. if ($element.Expression -is [System.Management.Automation.Language.ParenExpressionAst]) { $innerPipeline = $element.Expression.Pipeline if ($innerPipeline -and $innerPipeline.PipelineElements.Count -eq 1) { @@ -156,46 +214,44 @@ function Add-CommandsFromPipelineBase { return Add-CommandsFromPipelineAst $pipeline $commands } - if ($pipeline -is [System.Management.Automation.Language.PipelineChainAst]) { + # Windows PowerShell 5.1 does not define PipelineChainAst, so avoid a direct type + # reference here and instead check the runtime type name. + if ($pipeline.GetType().FullName -eq 'System.Management.Automation.Language.PipelineChainAst') { return Add-CommandsFromPipelineChain $pipeline $commands } return $false } -$commands = [System.Collections.ArrayList]::new() - -foreach ($statement in $ast.EndBlock.Statements) { - if (-not (Add-CommandsFromPipelineBase $statement $commands)) { - $commands = $null - break - } -} - -if ($commands -ne $null) { - $normalized = [System.Collections.ArrayList]::new() - foreach ($cmd in $commands) { - if ($cmd -is [string]) { - $null = $normalized.Add(@($cmd)) - continue - } - - if ($cmd -is [System.Array] -or $cmd -is [System.Collections.IEnumerable]) { - $null = $normalized.Add(@($cmd)) - continue - } - - $normalized = $null - break +# This script stays alive so the Rust caller can amortize PowerShell startup across +# many parse requests. Each request and response is one compact JSON line. +while (($requestLine = $stdin.ReadLine()) -ne $null) { + $request = $null + try { + $request = $requestLine | ConvertFrom-Json + } catch { + Write-Response @{ id = $null; status = 'parse_failed' } + continue } - $commands = $normalized -} + # We process requests serially, but still echo the id back so the Rust side can + # detect protocol desyncs instead of silently trusting mixed stdout. + $requestId = $request.id + $payload = $request.payload + if ([string]::IsNullOrEmpty($payload)) { + Write-Response @{ id = $requestId; status = 'parse_failed' } + continue + } -$result = if ($commands -eq $null) { - @{ status = 'unsupported' } -} else { - @{ status = 'ok'; commands = $commands } -} + try { + $source = + [System.Text.Encoding]::Unicode.GetString( + [System.Convert]::FromBase64String($payload) + ) + } catch { + Write-Response @{ id = $requestId; status = 'parse_failed' } + continue + } -,$result | ConvertTo-Json -Depth 3 + Write-Response (Invoke-ParseRequest -RequestId $requestId -Source $source) +} diff --git a/codex-rs/shell-command/src/command_safety/powershell_parser.rs b/codex-rs/shell-command/src/command_safety/powershell_parser.rs new file mode 100644 index 000000000..bf09e60be --- /dev/null +++ b/codex-rs/shell-command/src/command_safety/powershell_parser.rs @@ -0,0 +1,289 @@ +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use serde::Deserialize; +use serde::Serialize; +use std::collections::HashMap; +use std::io::BufRead; +use std::io::BufReader; +use std::io::ErrorKind; +use std::io::Write; +use std::process::Child; +use std::process::ChildStdin; +use std::process::ChildStdout; +use std::process::Command; +use std::process::Stdio; +use std::sync::LazyLock; +use std::sync::Mutex; +use std::sync::PoisonError; + +const POWERSHELL_PARSER_SCRIPT: &str = include_str!("powershell_parser.ps1"); + +/// Cache one long-lived parser process per executable path so repeated safety checks reuse +/// PowerShell startup work while still consulting the real parser every time. +/// +/// We keep the cache behind one mutex because each child process speaks a simple +/// request/response protocol over a single stdin/stdout pair, so callers targeting the same +/// executable must serialize access anyway. +pub(super) fn parse_with_powershell_ast(executable: &str, script: &str) -> PowershellParseOutcome { + static PARSER_PROCESSES: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + + let mut parser_processes = PARSER_PROCESSES + .lock() + .unwrap_or_else(PoisonError::into_inner); + parse_with_cached_process(&mut parser_processes, executable, script) +} + +#[derive(Debug, PartialEq, Eq)] +pub(super) enum PowershellParseOutcome { + Commands(Vec>), + Unsupported, + Failed, +} + +fn parse_with_cached_process( + parser_processes: &mut HashMap, + executable: &str, + script: &str, +) -> PowershellParseOutcome { + // `powershell.exe` and `pwsh.exe` do not accept the same language surface, so each + // executable keeps its own parser process and request stream. + let parser_key = executable.to_string(); + for attempt in 0..=1 { + if !parser_processes.contains_key(&parser_key) { + match PowershellParserProcess::spawn(executable) { + Ok(process) => { + parser_processes.insert(parser_key.clone(), process); + } + Err(_) => return PowershellParseOutcome::Failed, + } + } + + let Some(parser_process) = parser_processes.get_mut(&parser_key) else { + return PowershellParseOutcome::Failed; + }; + let parse_result = parser_process.parse(script); + match parse_result { + Ok(outcome) => return outcome, + Err(_) if attempt == 0 => { + // The common failure mode here is that a previously cached child exited or its + // stdio stream became unusable between requests. Drop that process and retry once + // with a fresh child before giving up. + parser_processes.remove(&parser_key); + } + Err(_) => return PowershellParseOutcome::Failed, + } + } + + PowershellParseOutcome::Failed +} + +fn encode_powershell_base64(script: &str) -> String { + let mut utf16 = Vec::with_capacity(script.len() * 2); + for unit in script.encode_utf16() { + utf16.extend_from_slice(&unit.to_le_bytes()); + } + BASE64_STANDARD.encode(utf16) +} + +fn encoded_parser_script() -> &'static str { + static ENCODED: LazyLock = + LazyLock::new(|| encode_powershell_base64(POWERSHELL_PARSER_SCRIPT)); + &ENCODED +} + +struct PowershellParserProcess { + child: Child, + stdin: ChildStdin, + stdout: BufReader, + // Request ids are monotonic within one child process so the caller can detect protocol + // desynchronization if stdout is contaminated or the child is unexpectedly replaced. + next_request_id: u64, +} + +impl PowershellParserProcess { + fn spawn(executable: &str) -> std::io::Result { + let mut child = Command::new(executable) + .args([ + "-NoLogo", + "-NoProfile", + "-NonInteractive", + "-EncodedCommand", + encoded_parser_script(), + ]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::null()) + .spawn()?; + let stdin = match take_child_stdin(&mut child) { + Ok(stdin) => stdin, + Err(error) => { + kill_child(&mut child); + return Err(error); + } + }; + let stdout = match take_child_stdout(&mut child) { + Ok(stdout) => stdout, + Err(error) => { + kill_child(&mut child); + return Err(error); + } + }; + Ok(Self { + child, + stdin, + stdout, + next_request_id: 0, + }) + } + + fn parse(&mut self, script: &str) -> std::io::Result { + let request = PowershellParserRequest { + id: self.next_request_id, + payload: encode_powershell_base64(script), + }; + self.next_request_id = self.next_request_id.wrapping_add(1); + let mut request_json = serialize_request(&request)?; + request_json.push('\n'); + self.stdin.write_all(request_json.as_bytes())?; + self.stdin.flush()?; + + let mut response_line = String::new(); + if self.stdout.read_line(&mut response_line)? == 0 { + return Err(std::io::Error::new( + ErrorKind::UnexpectedEof, + "PowerShell parser closed stdout", + )); + } + + let response = deserialize_response(&response_line)?; + // Requests are serialized today; the id still catches protocol desyncs if stdout is + // contaminated or the child process is unexpectedly replaced mid-request. That turns an + // ambiguous parser result into a hard failure so the caller can discard the cached child. + if response.id != request.id { + return Err(std::io::Error::new( + ErrorKind::InvalidData, + format!( + "PowerShell parser returned response id {} for request {}", + response.id, request.id + ), + )); + } + + Ok(response.into_outcome()) + } +} + +impl Drop for PowershellParserProcess { + fn drop(&mut self) { + kill_child(&mut self.child); + } +} + +fn take_child_stdin(child: &mut Child) -> std::io::Result { + child.stdin.take().ok_or_else(|| { + std::io::Error::new( + ErrorKind::BrokenPipe, + "PowerShell parser child did not expose stdin", + ) + }) +} + +fn take_child_stdout(child: &mut Child) -> std::io::Result> { + child.stdout.take().map(BufReader::new).ok_or_else(|| { + std::io::Error::new( + ErrorKind::BrokenPipe, + "PowerShell parser child did not expose stdout", + ) + }) +} + +fn serialize_request(request: &PowershellParserRequest) -> std::io::Result { + serde_json::to_string(request).map_err(|error| { + std::io::Error::new( + ErrorKind::InvalidData, + format!("failed to serialize PowerShell parser request: {error}"), + ) + }) +} + +fn deserialize_response(response_line: &str) -> std::io::Result { + serde_json::from_str(response_line).map_err(|error| { + std::io::Error::new( + ErrorKind::InvalidData, + format!("failed to parse PowerShell parser response: {error}"), + ) + }) +} + +#[derive(Serialize)] +struct PowershellParserRequest { + id: u64, + payload: String, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct PowershellParserResponse { + id: u64, + status: String, + commands: Option>>, +} + +impl PowershellParserResponse { + fn into_outcome(self) -> PowershellParseOutcome { + match self.status.as_str() { + "ok" => self + .commands + .filter(|commands| { + !commands.is_empty() + && commands + .iter() + .all(|cmd| !cmd.is_empty() && cmd.iter().all(|word| !word.is_empty())) + }) + .map(PowershellParseOutcome::Commands) + .unwrap_or(PowershellParseOutcome::Unsupported), + "unsupported" => PowershellParseOutcome::Unsupported, + _ => PowershellParseOutcome::Failed, + } + } +} + +fn kill_child(child: &mut Child) { + let _ = child.kill(); + let _ = child.wait(); +} + +#[cfg(all(test, windows))] +mod tests { + use super::*; + use crate::powershell::try_find_powershell_executable_blocking; + use pretty_assertions::assert_eq; + + #[test] + fn parser_process_handles_multiple_requests() { + let Some(powershell) = try_find_powershell_executable_blocking() else { + return; + }; + let powershell = powershell.as_path().to_str().unwrap(); + let mut parser = PowershellParserProcess::spawn(powershell).unwrap(); + + let first = parser.parse("Get-Content 'foo bar'").unwrap(); + assert_eq!( + first, + PowershellParseOutcome::Commands(vec![vec![ + "Get-Content".to_string(), + "foo bar".to_string(), + ]]), + ); + + let second = parser.parse("Write-Output foo | Measure-Object").unwrap(); + assert_eq!( + second, + PowershellParseOutcome::Commands(vec![ + vec!["Write-Output".to_string(), "foo".to_string()], + vec!["Measure-Object".to_string()], + ]), + ); + } +} diff --git a/codex-rs/shell-command/src/command_safety/windows_safe_commands.rs b/codex-rs/shell-command/src/command_safety/windows_safe_commands.rs index aa7ff9681..b6c5d863b 100644 --- a/codex-rs/shell-command/src/command_safety/windows_safe_commands.rs +++ b/codex-rs/shell-command/src/command_safety/windows_safe_commands.rs @@ -1,12 +1,7 @@ use crate::command_safety::is_dangerous_command::git_global_option_requires_prompt; -use base64::Engine; -use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; -use serde::Deserialize; +use crate::command_safety::powershell_parser::PowershellParseOutcome; +use crate::command_safety::powershell_parser::parse_with_powershell_ast; use std::path::Path; -use std::process::Command; -use std::sync::LazyLock; - -const POWERSHELL_PARSER_SCRIPT: &str = include_str!("powershell_parser.ps1"); /// On Windows, we conservatively allow only clearly read-only PowerShell invocations /// that match a small safelist. Anything else (including direct CMD commands) is unsafe. @@ -122,82 +117,6 @@ fn is_powershell_executable(exe: &str) -> bool { ) } -/// Attempts to parse PowerShell using the real PowerShell parser, returning every pipeline element -/// as a flat argv vector when possible. If parsing fails or the AST includes unsupported constructs, -/// we conservatively reject the command instead of trying to split it manually. -fn parse_with_powershell_ast(executable: &str, script: &str) -> PowershellParseOutcome { - let encoded_script = encode_powershell_base64(script); - let encoded_parser_script = encoded_parser_script(); - match Command::new(executable) - .args([ - "-NoLogo", - "-NoProfile", - "-NonInteractive", - "-EncodedCommand", - encoded_parser_script, - ]) - .env("CODEX_POWERSHELL_PAYLOAD", &encoded_script) - .output() - { - Ok(output) if output.status.success() => { - if let Ok(result) = - serde_json::from_slice::(output.stdout.as_slice()) - { - result.into_outcome() - } else { - PowershellParseOutcome::Failed - } - } - _ => PowershellParseOutcome::Failed, - } -} - -fn encode_powershell_base64(script: &str) -> String { - let mut utf16 = Vec::with_capacity(script.len() * 2); - for unit in script.encode_utf16() { - utf16.extend_from_slice(&unit.to_le_bytes()); - } - BASE64_STANDARD.encode(utf16) -} - -fn encoded_parser_script() -> &'static str { - static ENCODED: LazyLock = - LazyLock::new(|| encode_powershell_base64(POWERSHELL_PARSER_SCRIPT)); - &ENCODED -} - -#[derive(Deserialize)] -#[serde(deny_unknown_fields)] -struct PowershellParserOutput { - status: String, - commands: Option>>, -} - -impl PowershellParserOutput { - fn into_outcome(self) -> PowershellParseOutcome { - match self.status.as_str() { - "ok" => self - .commands - .filter(|commands| { - !commands.is_empty() - && commands - .iter() - .all(|cmd| !cmd.is_empty() && cmd.iter().all(|word| !word.is_empty())) - }) - .map(PowershellParseOutcome::Commands) - .unwrap_or(PowershellParseOutcome::Unsupported), - "unsupported" => PowershellParseOutcome::Unsupported, - _ => PowershellParseOutcome::Failed, - } - } -} - -enum PowershellParseOutcome { - Commands(Vec>), - Unsupported, - Failed, -} - fn join_arguments_as_script(args: &[String]) -> String { let mut words = Vec::with_capacity(args.len()); if let Some((first, rest)) = args.split_first() {