diff --git a/codex-rs/memories/mcp/src/backend.rs b/codex-rs/memories/mcp/src/backend.rs index a4b542cd6..f402a9b8a 100644 --- a/codex-rs/memories/mcp/src/backend.rs +++ b/codex-rs/memories/mcp/src/backend.rs @@ -87,11 +87,15 @@ pub struct SearchMemoriesResponse { pub truncated: bool, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] pub enum SearchMatchMode { Any, - All, + AllOnSameLine, + AllWithinLines { + #[schemars(range(min = 1))] + line_count: usize, + }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, JsonSchema)] @@ -136,6 +140,8 @@ pub enum MemoriesBackendError { NotFile { path: String }, #[error("queries must not be empty or contain empty strings")] EmptyQuery, + #[error("all_within_lines.line_count must be a positive integer")] + InvalidMatchWindow, #[error("I/O error while reading memories: {0}")] Io(#[from] std::io::Error), } diff --git a/codex-rs/memories/mcp/src/local.rs b/codex-rs/memories/mcp/src/local.rs index 0f642d43f..f16dba195 100644 --- a/codex-rs/memories/mcp/src/local.rs +++ b/codex-rs/memories/mcp/src/local.rs @@ -224,6 +224,12 @@ impl MemoriesBackend for LocalMemoriesBackend { if queries.is_empty() || queries.iter().any(std::string::String::is_empty) { return Err(MemoriesBackendError::EmptyQuery); } + if matches!( + request.match_mode, + SearchMatchMode::AllWithinLines { line_count: 0 } + ) { + return Err(MemoriesBackendError::InvalidMatchWindow); + } let max_results = request.max_results.min(MAX_SEARCH_RESULTS); let start = self.resolve_scoped_path(request.path.as_deref()).await?; @@ -240,8 +246,11 @@ impl MemoriesBackend for LocalMemoriesBackend { }; reject_symlink(&display_relative_path(&self.root, &start), &metadata)?; - let matcher = - SearchMatcher::new(queries.clone(), request.match_mode, request.case_sensitive); + let matcher = SearchMatcher::new( + queries.clone(), + request.match_mode.clone(), + request.case_sensitive, + ); let mut matches = Vec::new(); search_entries( &self.root, @@ -329,26 +338,116 @@ async fn search_file( Err(err) => return Err(err.into()), }; let lines = content.lines().collect::>(); - for (idx, line) in lines.iter().enumerate() { - let matched_queries = matcher.matched_queries(line); - if !matched_queries.is_empty() { - let start_index = idx.saturating_sub(context_lines); - let end_index = idx - .saturating_add(context_lines) - .saturating_add(1) - .min(lines.len()); - matches.push(MemorySearchMatch { - path: display_relative_path(root, path), - match_line_number: idx + 1, - content_start_line_number: start_index + 1, - content: lines[start_index..end_index].join("\n"), - matched_queries, - }); + let line_matches = lines + .iter() + .map(|line| matcher.matched_query_flags(line)) + .collect::>(); + match &matcher.match_mode { + SearchMatchMode::Any => { + for (idx, matched_query_flags) in line_matches.iter().enumerate() { + if matched_query_flags.iter().any(|matched| *matched) { + matches.push(build_search_match( + root, + path, + &lines, + idx, + idx, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + SearchMatchMode::AllOnSameLine => { + for (idx, matched_query_flags) in line_matches.iter().enumerate() { + if matched_query_flags.iter().all(|matched| *matched) { + matches.push(build_search_match( + root, + path, + &lines, + idx, + idx, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } + } + } + SearchMatchMode::AllWithinLines { line_count } => { + let mut windows = Vec::new(); + for start_index in 0..lines.len() { + if !line_matches[start_index].iter().any(|matched| *matched) { + continue; + } + let last_allowed_index = start_index + .saturating_add(line_count.saturating_sub(1)) + .min(lines.len().saturating_sub(1)); + let mut matched_query_flags = vec![false; matcher.queries.len()]; + for (end_index, line_match_flags) in line_matches + .iter() + .enumerate() + .take(last_allowed_index + 1) + .skip(start_index) + { + for (idx, matched) in line_match_flags.iter().enumerate() { + matched_query_flags[idx] |= matched; + } + if matched_query_flags.iter().all(|matched| *matched) { + windows.push((start_index, end_index, matched_query_flags)); + break; + } + } + } + for (idx, (start_index, end_index, matched_query_flags)) in windows.iter().enumerate() { + let strictly_contains_another_window = windows.iter().enumerate().any( + |(other_idx, (other_start_index, other_end_index, _))| { + idx != other_idx + && start_index <= other_start_index + && end_index >= other_end_index + && (start_index != other_start_index || end_index != other_end_index) + }, + ); + if strictly_contains_another_window { + continue; + } + matches.push(build_search_match( + root, + path, + &lines, + *start_index, + *end_index, + context_lines, + matcher.matched_queries(matched_query_flags), + )); + } } } Ok(()) } +fn build_search_match( + root: &Path, + path: &Path, + lines: &[&str], + match_start_index: usize, + match_end_index: usize, + context_lines: usize, + matched_queries: Vec, +) -> MemorySearchMatch { + let content_start_index = match_start_index.saturating_sub(context_lines); + let content_end_index = match_end_index + .saturating_add(context_lines) + .saturating_add(1) + .min(lines.len()); + MemorySearchMatch { + path: display_relative_path(root, path), + match_line_number: match_start_index + 1, + content_start_line_number: content_start_index + 1, + content: lines[content_start_index..content_end_index].join("\n"), + matched_queries, + } +} + struct SearchMatcher { queries: Vec, normalized_queries: Option>, @@ -370,23 +469,24 @@ impl SearchMatcher { } } - fn matched_queries(&self, line: &str) -> Vec { + fn matched_query_flags(&self, line: &str) -> Vec { let line = match self.normalized_queries.as_ref() { Some(_) => Cow::Owned(line.to_lowercase()), None => Cow::Borrowed(line), }; let queries = self.normalized_queries.as_deref().unwrap_or(&self.queries); - let mut matched_queries = Vec::new(); - for (idx, query) in queries.iter().enumerate() { - if line.as_ref().contains(query) { - matched_queries.push(self.queries[idx].clone()); - } - } - match self.match_mode { - SearchMatchMode::Any => matched_queries, - SearchMatchMode::All if matched_queries.len() == self.queries.len() => matched_queries, - SearchMatchMode::All => Vec::new(), - } + queries + .iter() + .map(|query| line.as_ref().contains(query)) + .collect() + } + + fn matched_queries(&self, matched_query_flags: &[bool]) -> Vec { + self.queries + .iter() + .zip(matched_query_flags) + .filter_map(|(query, matched)| matched.then_some(query.clone())) + .collect() } } diff --git a/codex-rs/memories/mcp/src/local_tests.rs b/codex-rs/memories/mcp/src/local_tests.rs index 777348953..4edf06eef 100644 --- a/codex-rs/memories/mcp/src/local_tests.rs +++ b/codex-rs/memories/mcp/src/local_tests.rs @@ -752,7 +752,7 @@ async fn search_supports_case_insensitive_matching() { } #[tokio::test] -async fn search_supports_any_and_all_match_modes() { +async fn search_supports_any_and_all_on_same_line_match_modes() { let tempdir = TempDir::new().expect("tempdir"); tokio::fs::write( tempdir.path().join("MEMORY.md"), @@ -793,11 +793,11 @@ async fn search_supports_any_and_all_match_modes() { ); let mut request = search_request(&["alpha", "needle"]); - request.match_mode = SearchMatchMode::All; + request.match_mode = SearchMatchMode::AllOnSameLine; let all_response = backend(&tempdir) .search(request) .await - .expect("search with all match mode"); + .expect("search with all-on-same-line match mode"); assert_eq!( all_response.matches, vec![MemorySearchMatch { @@ -810,6 +810,63 @@ async fn search_supports_any_and_all_match_modes() { ); } +#[tokio::test] +async fn search_supports_all_within_lines_match_mode() { + let tempdir = TempDir::new().expect("tempdir"); + tokio::fs::write( + tempdir.path().join("MEMORY.md"), + "alpha first\nmiddle\nneedle later\nalpha again needle together\n", + ) + .await + .expect("write memory file"); + + let mut request = search_request(&["alpha", "needle"]); + request.match_mode = SearchMatchMode::AllWithinLines { line_count: 3 }; + request.context_lines = 1; + let response = backend(&tempdir) + .search(request) + .await + .expect("search with all-within-lines match mode"); + + assert_eq!( + response.matches, + vec![ + MemorySearchMatch { + path: "MEMORY.md".to_string(), + match_line_number: 1, + content_start_line_number: 1, + content: "alpha first\nmiddle\nneedle later\nalpha again needle together" + .to_string(), + matched_queries: vec!["alpha".to_string(), "needle".to_string()], + }, + MemorySearchMatch { + path: "MEMORY.md".to_string(), + match_line_number: 4, + content_start_line_number: 3, + content: "needle later\nalpha again needle together".to_string(), + matched_queries: vec!["alpha".to_string(), "needle".to_string()], + }, + ] + ); +} + +#[tokio::test] +async fn search_rejects_zero_line_window() { + let tempdir = TempDir::new().expect("tempdir"); + tokio::fs::write(tempdir.path().join("MEMORY.md"), "needle\n") + .await + .expect("write memory file"); + + let mut request = search_request(&["needle"]); + request.match_mode = SearchMatchMode::AllWithinLines { line_count: 0 }; + let err = backend(&tempdir) + .search(request) + .await + .expect_err("zero-width windows should be rejected"); + + assert!(matches!(err, MemoriesBackendError::InvalidMatchWindow)); +} + #[tokio::test] async fn search_rejects_invalid_cursor() { let tempdir = TempDir::new().expect("tempdir"); diff --git a/codex-rs/memories/mcp/src/server.rs b/codex-rs/memories/mcp/src/server.rs index 25cc7252b..88eeb8fa1 100644 --- a/codex-rs/memories/mcp/src/server.rs +++ b/codex-rs/memories/mcp/src/server.rs @@ -227,7 +227,7 @@ fn search_tool() -> Tool { let mut tool = Tool::new( Cow::Borrowed(SEARCH_TOOL_NAME), Cow::Borrowed( - "Search Codex memory files for line-based substring matches, optionally requiring any or all query substrings on the same line.", + "Search Codex memory files for substring matches, optionally requiring all query substrings on the same line or within a line window.", ), Arc::new(schema::input_schema_for::()), ); @@ -273,7 +273,10 @@ fn backend_error_to_mcp(err: MemoriesBackendError) -> McpError { | MemoriesBackendError::InvalidMaxLines | MemoriesBackendError::LineOffsetExceedsFileLength | MemoriesBackendError::NotFile { .. } - | MemoriesBackendError::EmptyQuery => McpError::invalid_params(err.to_string(), None), + | MemoriesBackendError::EmptyQuery + | MemoriesBackendError::InvalidMatchWindow => { + McpError::invalid_params(err.to_string(), None) + } MemoriesBackendError::Io(_) => McpError::internal_error(err.to_string(), None), } } @@ -308,6 +311,33 @@ mod tests { ); } + #[test] + fn search_args_accept_windowed_all_match_mode() { + let args: SearchArgs = parse_args(json!({ + "queries": ["alpha", "needle"], + "match_mode": { + "type": "all_within_lines", + "line_count": 3 + } + })) + .expect("windowed all args should parse"); + + let request = args.into_request(); + + assert_eq!( + request, + SearchMemoriesRequest { + queries: vec!["alpha".to_string(), "needle".to_string()], + match_mode: SearchMatchMode::AllWithinLines { line_count: 3 }, + path: None, + cursor: None, + context_lines: 0, + case_sensitive: true, + max_results: DEFAULT_SEARCH_MAX_RESULTS, + } + ); + } + #[test] fn search_args_reject_legacy_single_query() { let err = parse_args::(json!({