diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index e9c171d77..f5e8b5623 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -3293,6 +3293,7 @@ name = "codex-message-history" version = "0.0.0" dependencies = [ "codex-config", + "memchr", "pretty_assertions", "serde", "serde_json", @@ -3407,6 +3408,7 @@ dependencies = [ "codex-core", "codex-model-provider-info", "futures", + "memchr", "pretty_assertions", "reqwest 0.12.28", "semver", @@ -3576,6 +3578,7 @@ dependencies = [ "codex-utils-pty", "futures", "keyring", + "memchr", "oauth2", "pretty_assertions", "reqwest 0.13.4", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 59d161958..15e0e2c01 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -318,6 +318,7 @@ libc = "0.2.182" log = "0.4" lru = "0.16.3" maplit = "1.0.2" +memchr = "2.7.6" mime_guess = "2.0.5" multimap = "0.10.0" notify = "8.2.0" diff --git a/codex-rs/message-history/Cargo.toml b/codex-rs/message-history/Cargo.toml index b67933d1d..176b95790 100644 --- a/codex-rs/message-history/Cargo.toml +++ b/codex-rs/message-history/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] codex-config = { workspace = true } +memchr = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["fs", "io-util", "rt"] } diff --git a/codex-rs/message-history/src/lib.rs b/codex-rs/message-history/src/lib.rs index 0d85cb8fe..18a6d791e 100644 --- a/codex-rs/message-history/src/lib.rs +++ b/codex-rs/message-history/src/lib.rs @@ -26,6 +26,7 @@ use std::io::Write; use std::path::Path; use std::path::PathBuf; +use memchr::memchr_iter; use serde::Deserialize; use serde::Serialize; @@ -43,6 +44,7 @@ use std::os::unix::fs::PermissionsExt; /// Filename that stores the message history inside `~/.codex`. const HISTORY_FILENAME: &str = "history.jsonl"; +const HISTORY_READ_BUFFER_SIZE: usize = 8192; /// When history exceeds the hard cap, trim it down to this fraction of `max_bytes`. const HISTORY_SOFT_CAP_RATIO: f64 = 0.8; @@ -330,13 +332,13 @@ async fn history_metadata_for_file(path: &Path) -> (u64, usize) { }; // Count newline bytes. - let mut buf = [0u8; 8192]; + let mut buf = [0u8; HISTORY_READ_BUFFER_SIZE]; let mut count = 0usize; loop { match file.read(&mut buf).await { Ok(0) => break, Ok(n) => { - count += buf[..n].iter().filter(|&&b| b == b'\n').count(); + count += memchr_iter(b'\n', &buf[..n]).count(); } Err(_) => return (log_id, 0), } diff --git a/codex-rs/message-history/src/tests.rs b/codex-rs/message-history/src/tests.rs index 88f0b7e00..f4c7e66e1 100644 --- a/codex-rs/message-history/src/tests.rs +++ b/codex-rs/message-history/src/tests.rs @@ -41,6 +41,28 @@ async fn lookup_reads_history_entries() { assert_eq!(second_entry, entries[1]); } +#[tokio::test] +async fn history_metadata_counts_newlines_across_read_boundaries() { + let temp_dir = TempDir::new().expect("create temp dir"); + let history_path = temp_dir.path().join(HISTORY_FILENAME); + let mut contents = vec![b'x'; 3 * HISTORY_READ_BUFFER_SIZE + 1]; + let newline_offsets = [ + 0, + HISTORY_READ_BUFFER_SIZE - 1, + HISTORY_READ_BUFFER_SIZE, + 2 * HISTORY_READ_BUFFER_SIZE, + contents.len() - 2, + ]; + for offset in newline_offsets { + contents[offset] = b'\n'; + } + std::fs::write(&history_path, contents).expect("write history file"); + + let (_, count) = history_metadata_for_file(&history_path).await; + + assert_eq!(count, newline_offsets.len()); +} + #[tokio::test] async fn lookup_uses_stable_log_id_after_appends() { let temp_dir = TempDir::new().expect("create temp dir"); diff --git a/codex-rs/ollama/Cargo.toml b/codex-rs/ollama/Cargo.toml index 5d30fbda2..42e32159f 100644 --- a/codex-rs/ollama/Cargo.toml +++ b/codex-rs/ollama/Cargo.toml @@ -18,6 +18,7 @@ bytes = { workspace = true } codex-core = { workspace = true } codex-model-provider-info = { workspace = true } futures = { workspace = true } +memchr = { workspace = true } reqwest = { workspace = true, features = ["json", "stream"] } semver = { workspace = true } serde_json = { workspace = true } diff --git a/codex-rs/ollama/src/client.rs b/codex-rs/ollama/src/client.rs index 96601d086..232c14683 100644 --- a/codex-rs/ollama/src/client.rs +++ b/codex-rs/ollama/src/client.rs @@ -1,4 +1,3 @@ -use bytes::BytesMut; use futures::StreamExt; use futures::stream::BoxStream; use semver::Version; @@ -6,6 +5,7 @@ use serde_json::Value as JsonValue; use std::collections::VecDeque; use std::io; +use crate::line_buffer::LineBuffer; use crate::parser::pull_events_from_value; use crate::pull::PullEvent; use crate::pull::PullProgressReporter; @@ -174,7 +174,7 @@ impl OllamaClient { } let mut stream = resp.bytes_stream(); - let mut buf = BytesMut::new(); + let mut buf = LineBuffer::default(); let _pending: VecDeque = VecDeque::new(); // Using an async stream adaptor backed by unfold-like manual loop. @@ -183,8 +183,7 @@ impl OllamaClient { match chunk { Ok(bytes) => { buf.extend_from_slice(&bytes); - while let Some(pos) = buf.iter().position(|b| *b == b'\n') { - let line = buf.split_to(pos + 1); + while let Some(line) = buf.take_line() { if let Ok(text) = std::str::from_utf8(&line) { let text = text.trim(); if text.is_empty() { continue; } @@ -263,6 +262,7 @@ impl OllamaClient { #[cfg(test)] mod tests { use super::*; + use assert_matches::assert_matches; use pretty_assertions::assert_eq; // Happy-path tests using a mock HTTP server; skip if sandbox network is disabled. @@ -333,6 +333,50 @@ mod tests { assert_eq!(version, Some(Version::new(0, 14, 1))); } + #[tokio::test] + async fn test_pull_model_stream_parses_large_json_lines() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} set; skipping test_pull_model_stream_parses_large_json_lines", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + let body = format!( + "{}\n{}\n", + serde_json::json!({ + "status": "pulling layers", + "padding": "x".repeat(128 * 1024), + }), + serde_json::json!({"status": "complete"}), + ); + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/api/pull")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw(body, "application/x-ndjson"), + ) + .mount(&server) + .await; + + let client = OllamaClient::from_host_root(server.uri()); + let events = client + .pull_model_stream("test-model") + .await + .expect("start pull stream") + .collect::>() + .await; + + assert_matches!( + events.as_slice(), + [ + PullEvent::Status(pulling), + PullEvent::Status(complete), + ] if pulling == "pulling layers" && complete == "complete" + ); + } + #[tokio::test] async fn test_probe_server_happy_path_openai_compat_and_native() { if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { diff --git a/codex-rs/ollama/src/lib.rs b/codex-rs/ollama/src/lib.rs index 75800845b..d0de830bc 100644 --- a/codex-rs/ollama/src/lib.rs +++ b/codex-rs/ollama/src/lib.rs @@ -1,4 +1,5 @@ mod client; +mod line_buffer; mod parser; mod pull; mod url; diff --git a/codex-rs/ollama/src/line_buffer.rs b/codex-rs/ollama/src/line_buffer.rs new file mode 100644 index 000000000..fb3d4b98f --- /dev/null +++ b/codex-rs/ollama/src/line_buffer.rs @@ -0,0 +1,32 @@ +use bytes::BytesMut; +use memchr::memchr; + +#[derive(Default)] +#[cfg_attr(test, derive(Debug, PartialEq, Eq))] +pub(crate) struct LineBuffer { + bytes: BytesMut, + /// Prefix already scanned and known not to contain a newline. + scanned_len: usize, +} + +impl LineBuffer { + pub(crate) fn extend_from_slice(&mut self, bytes: &[u8]) { + self.bytes.extend_from_slice(bytes); + } + + pub(crate) fn take_line(&mut self) -> Option { + let Some(relative_index) = memchr(b'\n', &self.bytes[self.scanned_len..]) else { + self.scanned_len = self.bytes.len(); + return None; + }; + + let newline_index = self.scanned_len + relative_index; + let line = self.bytes.split_to(newline_index + 1); + self.scanned_len = 0; + Some(line) + } +} + +#[cfg(test)] +#[path = "line_buffer_tests.rs"] +mod tests; diff --git a/codex-rs/ollama/src/line_buffer_tests.rs b/codex-rs/ollama/src/line_buffer_tests.rs new file mode 100644 index 000000000..58de07e86 --- /dev/null +++ b/codex-rs/ollama/src/line_buffer_tests.rs @@ -0,0 +1,42 @@ +use bytes::BytesMut; +use pretty_assertions::assert_eq; + +use super::LineBuffer; + +#[test] +fn searches_only_new_bytes_after_partial_line() { + let mut buffer = LineBuffer::default(); + + buffer.extend_from_slice(b"partial"); + assert_eq!(buffer.take_line(), None); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"partial"[..]), + scanned_len: 7, + } + ); + + buffer.extend_from_slice(b" line"); + assert_eq!(buffer.take_line(), None); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"partial line"[..]), + scanned_len: 12, + } + ); + + buffer.extend_from_slice(b"\nnext"); + assert_eq!( + buffer.take_line(), + Some(BytesMut::from(&b"partial line\n"[..])) + ); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"next"[..]), + scanned_len: 0, + } + ); +} diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 63c6d74e4..e3417de70 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -26,6 +26,7 @@ codex-utils-home-dir = { workspace = true } bytes = { workspace = true } futures = { workspace = true, default-features = false, features = ["std"] } keyring = { workspace = true, features = ["crypto-rust"] } +memchr = { workspace = true } oauth2 = "5" reqwest = { version = "0.13", default-features = false, features = [ "json", diff --git a/codex-rs/rmcp-client/src/executor_process_transport.rs b/codex-rs/rmcp-client/src/executor_process_transport.rs index 41f0b7660..cff4f517e 100644 --- a/codex-rs/rmcp-client/src/executor_process_transport.rs +++ b/codex-rs/rmcp-client/src/executor_process_transport.rs @@ -20,11 +20,11 @@ use std::future::Future; use std::io; -use std::mem::take; use std::sync::Arc; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; +use bytes::BytesMut; use codex_exec_server::ExecOutputStream; use codex_exec_server::ExecProcess; use codex_exec_server::ExecProcessEvent; @@ -32,6 +32,7 @@ use codex_exec_server::ExecProcessEventReceiver; use codex_exec_server::ProcessId; use codex_exec_server::ProcessOutputChunk; use codex_exec_server::WriteStatus; +use memchr::memchr; use rmcp::service::RoleClient; use rmcp::service::RxJsonRpcMessage; use rmcp::service::TxJsonRpcMessage; @@ -46,6 +47,42 @@ use tracing::warn; static PROCESS_COUNTER: AtomicUsize = AtomicUsize::new(1); +#[derive(Default)] +#[cfg_attr(test, derive(Debug, PartialEq, Eq))] +struct LineBuffer { + bytes: BytesMut, + /// Prefix already scanned and known not to contain a newline. + scanned_len: usize, +} + +impl LineBuffer { + fn extend_from_slice(&mut self, bytes: &[u8]) { + self.bytes.extend_from_slice(bytes); + } + + fn take_line(&mut self) -> Option { + let Some(relative_index) = memchr(b'\n', &self.bytes[self.scanned_len..]) else { + self.scanned_len = self.bytes.len(); + return None; + }; + + let newline_index = self.scanned_len + relative_index; + let mut line = self.bytes.split_to(newline_index + 1); + line.truncate(newline_index); + self.scanned_len = 0; + Some(line) + } + + fn take_remaining(&mut self) -> Option { + if self.bytes.is_empty() { + return None; + } + + self.scanned_len = 0; + Some(self.bytes.split()) + } +} + // Remote public implementation. /// A client-side rmcp transport backed by an executor-managed process. @@ -73,10 +110,10 @@ pub(super) struct ExecutorProcessTransport { /// Buffered child stdout bytes that have not yet formed a complete /// newline-delimited JSON-RPC message. - stdout: Vec, + stdout: LineBuffer, /// Buffered stderr bytes for diagnostic logging. - stderr: Vec, + stderr: LineBuffer, /// Whether the executor has reported process closure or a terminal /// subscription failure. Once closed, any remaining partial stdout line is @@ -105,8 +142,8 @@ impl ExecutorProcessTransport { process, events, program_name, - stdout: Vec::new(), - stderr: Vec::new(), + stdout: LineBuffer::default(), + stderr: LineBuffer::default(), closed: false, terminated: false, last_seq: 0, @@ -288,15 +325,10 @@ impl ExecutorProcessTransport { // so EOF after a complete JSON object behaves like local rmcp's // `decode_eof` handling. loop { - let line_end = self.stdout.iter().position(|byte| *byte == b'\n'); - let line = match (line_end, allow_partial && !self.stdout.is_empty()) { - (Some(index), _) => { - let mut line = self.stdout.drain(..=index).collect::>(); - line.pop(); - line - } - (None, true) => self.stdout.drain(..).collect(), - (None, false) => return None, + let line = match self.stdout.take_line() { + Some(line) => line, + None if allow_partial => self.stdout.take_remaining()?, + None => return None, }; let line = Self::trim_trailing_carriage_return(line); match from_slice::>(&line) { @@ -315,12 +347,8 @@ impl ExecutorProcessTransport { // Keep stderr line-oriented in logs so a chatty MCP server does not // produce one log record per byte chunk. self.stderr.extend_from_slice(bytes); - while let Some(index) = self.stderr.iter().position(|byte| *byte == b'\n') { - let mut line = self.stderr.drain(..=index).collect::>(); - line.pop(); - if line.last() == Some(&b'\r') { - line.pop(); - } + while let Some(line) = self.stderr.take_line() { + let line = Self::trim_trailing_carriage_return(line); info!( "MCP server stderr ({}): {}", self.program_name, @@ -330,10 +358,9 @@ impl ExecutorProcessTransport { } fn flush_stderr(&mut self) { - if self.stderr.is_empty() { + let Some(line) = self.stderr.take_remaining() else { return; - } - let line = take(&mut self.stderr); + }; info!( "MCP server stderr ({}): {}", self.program_name, @@ -341,14 +368,18 @@ impl ExecutorProcessTransport { ); } - fn trim_trailing_carriage_return(mut line: Vec) -> Vec { + fn trim_trailing_carriage_return(mut line: BytesMut) -> BytesMut { if line.last() == Some(&b'\r') { - line.pop(); + line.truncate(line.len() - 1); } line } } +#[cfg(test)] +#[path = "executor_process_transport_tests.rs"] +mod tests; + impl Drop for ExecutorProcessTransport { fn drop(&mut self) { if self.terminated { diff --git a/codex-rs/rmcp-client/src/executor_process_transport_tests.rs b/codex-rs/rmcp-client/src/executor_process_transport_tests.rs new file mode 100644 index 000000000..aeef11a76 --- /dev/null +++ b/codex-rs/rmcp-client/src/executor_process_transport_tests.rs @@ -0,0 +1,72 @@ +use bytes::BytesMut; +use pretty_assertions::assert_eq; + +use super::LineBuffer; + +#[test] +fn searches_only_new_bytes_after_partial_line() { + let mut buffer = LineBuffer::default(); + + buffer.extend_from_slice(b"partial"); + assert_eq!(buffer.take_line(), None); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"partial"[..]), + scanned_len: 7, + } + ); + + buffer.extend_from_slice(b" line"); + assert_eq!(buffer.take_line(), None); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"partial line"[..]), + scanned_len: 12, + } + ); + + buffer.extend_from_slice(b"\nnext"); + assert_eq!( + buffer.take_line(), + Some(BytesMut::from(&b"partial line"[..])) + ); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"next"[..]), + scanned_len: 0, + } + ); +} + +#[test] +fn splits_multiple_lines_and_retains_partial_tail() { + let mut buffer = LineBuffer::default(); + buffer.extend_from_slice(b"first\nsecond\npartial"); + + assert_eq!(buffer.take_line(), Some(BytesMut::from(&b"first"[..]))); + assert_eq!(buffer.take_line(), Some(BytesMut::from(&b"second"[..]))); + assert_eq!(buffer.take_line(), None); + assert_eq!( + buffer, + LineBuffer { + bytes: BytesMut::from(&b"partial"[..]), + scanned_len: 7, + } + ); +} + +#[test] +fn takes_unterminated_remaining_bytes_at_eof() { + let mut buffer = LineBuffer::default(); + buffer.extend_from_slice(b"remaining"); + assert_eq!(buffer.take_line(), None); + + assert_eq!( + buffer.take_remaining(), + Some(BytesMut::from(&b"remaining"[..])) + ); + assert_eq!(buffer, LineBuffer::default()); +}