diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index dff94692e..be3b3e3a3 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -1033,6 +1033,7 @@ impl Session { ), ), code_mode_service: crate::tools::code_mode::CodeModeService::new(), + tool_search_handler_cache: Default::default(), environment_manager, }; let (out_of_band_elicitation_paused, _out_of_band_elicitation_paused_rx) = diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index d328a6a5a..9eee4da39 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -571,6 +571,7 @@ fn test_tool_runtime(session: Arc, turn_context: Arc) -> T extension_tool_executors: Vec::new(), dynamic_tools: turn_context.dynamic_tools.as_slice(), }, + &Default::default(), )); let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); ToolCallRuntime::new(router, session, turn_context, tracker) @@ -5011,6 +5012,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { /*attestation_provider*/ None, ), code_mode_service: crate::tools::code_mode::CodeModeService::new(), + tool_search_handler_cache: Default::default(), environment_manager: Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), }; @@ -7016,6 +7018,7 @@ where /*attestation_provider*/ None, ), code_mode_service: crate::tools::code_mode::CodeModeService::new(), + tool_search_handler_cache: Default::default(), environment_manager: Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()), }; @@ -9531,6 +9534,7 @@ async fn fatal_tool_error_stops_turn_and_reports_error() { extension_tool_executors: Vec::new(), dynamic_tools: turn_context.dynamic_tools.as_slice(), }, + &Default::default(), ); let item = ResponseItem::CustomToolCall { id: None, diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 49e64e85e..3df795296 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -1234,6 +1234,7 @@ pub(crate) async fn built_tools( extension_tool_executors: extension_tool_executors(sess), dynamic_tools: turn_context.dynamic_tools.as_slice(), }, + &sess.services.tool_search_handler_cache, ))) } diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index 2901b2ae5..30ef37048 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -13,6 +13,7 @@ use crate::guardian::GuardianRejectionCircuitBreaker; use crate::mcp::McpManager; use crate::shell_snapshot::ShellSnapshot; use crate::tools::code_mode::CodeModeService; +use crate::tools::handlers::ToolSearchHandlerCache; use crate::tools::network_approval::NetworkApprovalService; use crate::tools::sandboxing::ApprovalStore; use crate::unified_exec::UnifiedExecProcessManager; @@ -81,6 +82,7 @@ pub(crate) struct SessionServices { /// Session-scoped model client shared across turns. pub(crate) model_client: ModelClient, pub(crate) code_mode_service: CodeModeService, + pub(crate) tool_search_handler_cache: ToolSearchHandlerCache, /// Shared process-level environment registry. Sessions carry an `Arc` handle so they can pass /// the same manager through child-thread spawn paths without reconstructing it. pub(crate) environment_manager: Arc, diff --git a/codex-rs/core/src/stream_events_utils_tests.rs b/codex-rs/core/src/stream_events_utils_tests.rs index c151a46fb..560f917fb 100644 --- a/codex-rs/core/src/stream_events_utils_tests.rs +++ b/codex-rs/core/src/stream_events_utils_tests.rs @@ -276,6 +276,7 @@ async fn handle_output_item_done_returns_contributed_last_agent_message() { extension_tool_executors: Vec::new(), dynamic_tools: turn_context.dynamic_tools.as_slice(), }, + &Default::default(), )); let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let tool_runtime = ToolCallRuntime::new( diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index 258c9914d..7791a819a 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -69,7 +69,7 @@ pub use request_user_input::RequestUserInputHandler; pub use shell::ShellCommandHandler; pub(crate) use shell::ShellCommandHandlerOptions; pub use test_sync::TestSyncHandler; -pub use tool_search::ToolSearchHandler; +pub(crate) use tool_search::ToolSearchHandlerCache; pub use unified_exec::ExecCommandHandler; pub(crate) use unified_exec::ExecCommandHandlerOptions; pub use unified_exec::WriteStdinHandler; diff --git a/codex-rs/core/src/tools/handlers/tool_search.rs b/codex-rs/core/src/tools/handlers/tool_search.rs index 32c55a578..5c201388b 100644 --- a/codex-rs/core/src/tools/handlers/tool_search.rs +++ b/codex-rs/core/src/tools/handlers/tool_search.rs @@ -16,29 +16,63 @@ use codex_tools::TOOL_SEARCH_TOOL_NAME; use codex_tools::ToolName; use codex_tools::ToolSearchEntry; use codex_tools::ToolSearchInfo; -use codex_tools::ToolSearchSourceInfo; use codex_tools::ToolSpec; use codex_tools::coalesce_loadable_tool_specs; +use std::sync::Arc; +use std::sync::Mutex; pub struct ToolSearchHandler { - entries: Vec, - search_source_infos: Vec, + search_infos: Vec, + spec: ToolSpec, search_engine: SearchEngine, } +#[derive(Default)] +pub(crate) struct ToolSearchHandlerCache { + cached: Mutex>>, +} + +impl ToolSearchHandlerCache { + pub(crate) fn get_or_build(&self, search_infos: Vec) -> Arc { + { + let cached = self.cached(); + if let Some(cached) = cached.as_ref() + && cached.search_infos == search_infos + { + return Arc::clone(cached); + } + } + + let handler = Arc::new(ToolSearchHandler::new(search_infos)); + let mut cached = self.cached(); + if let Some(cached) = cached.as_ref() + && cached.search_infos == handler.search_infos + { + return Arc::clone(cached); + } + + *cached = Some(Arc::clone(&handler)); + handler + } + + fn cached(&self) -> std::sync::MutexGuard<'_, Option>> { + match self.cached.lock() { + Ok(cached) => cached, + Err(poisoned) => poisoned.into_inner(), + } + } +} + impl ToolSearchHandler { pub(crate) fn new(search_infos: Vec) -> Self { - let mut entries = Vec::with_capacity(search_infos.len()); - let mut search_source_infos = Vec::new(); - for search_info in search_infos { - entries.push(search_info.entry); - if let Some(source_info) = search_info.source_info { - search_source_infos.push(source_info); - } - } - let documents: Vec> = entries + let search_source_infos = search_infos .iter() - .map(|entry| entry.search_text.clone()) + .filter_map(|search_info| search_info.source_info.clone()) + .collect::>(); + let spec = create_tool_search_tool(&search_source_infos, TOOL_SEARCH_DEFAULT_LIMIT); + let documents: Vec> = search_infos + .iter() + .map(|search_info| search_info.entry.search_text.clone()) .enumerate() .map(|(idx, search_text)| Document::new(idx, search_text)) .collect(); @@ -46,8 +80,8 @@ impl ToolSearchHandler { SearchEngineBuilder::::with_documents(Language::English, documents).build(); Self { - entries, - search_source_infos, + search_infos, + spec, search_engine, } } @@ -59,7 +93,7 @@ impl ToolExecutor for ToolSearchHandler { } fn spec(&self) -> ToolSpec { - create_tool_search_tool(&self.search_source_infos, TOOL_SEARCH_DEFAULT_LIMIT) + self.spec.clone() } fn supports_parallel_tool_calls(&self) -> bool { @@ -101,7 +135,7 @@ impl ToolSearchHandler { )); } - if self.entries.is_empty() { + if self.search_infos.is_empty() { return Ok(boxed_tool_output(ToolSearchOutput { tools: Vec::new() })); } @@ -124,7 +158,8 @@ impl ToolSearchHandler { .search(query, limit) .into_iter() .map(|result| result.document.id) - .filter_map(|id| self.entries.get(id)); + .filter_map(|id| self.search_infos.get(id)) + .map(|search_info| &search_info.entry); self.search_output_tools(results) } @@ -153,6 +188,29 @@ mod tests { use rmcp::model::Tool; use std::sync::Arc; + #[test] + fn cache_reuses_handler_for_identical_search_infos_and_rebuilds_for_changes() { + let cache = ToolSearchHandlerCache::default(); + let search_infos = vec![ + McpHandler::new(tool_info("calendar", "create_event", "Create events")) + .expect("MCP tool should convert") + .search_info() + .expect("MCP handler should return search info"), + ]; + + let first = cache.get_or_build(search_infos.clone()); + let second = cache.get_or_build(search_infos.clone()); + assert!(Arc::ptr_eq(&first, &second)); + + let mut changed_search_infos = search_infos; + changed_search_infos[0] + .entry + .search_text + .push_str(" changed"); + let changed = cache.get_or_build(changed_search_infos); + assert!(!Arc::ptr_eq(&first, &changed)); + } + #[test] fn mixed_search_results_coalesce_mcp_namespaces() { let dynamic_namespace = DynamicToolNamespaceSpec { @@ -194,9 +252,9 @@ mod tests { })); let handler = ToolSearchHandler::new(search_infos); let results = [ - &handler.entries[0], - &handler.entries[2], - &handler.entries[1], + &handler.search_infos[0].entry, + &handler.search_infos[2].entry, + &handler.search_infos[1].entry, ]; let tools = handler diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index a095499ea..d7f622c03 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -4,6 +4,7 @@ use crate::session::turn_context::TurnContext; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::handlers::ToolSearchHandlerCache; use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolArgumentDiffConsumer; use crate::tools::registry::ToolRegistry; @@ -45,8 +46,12 @@ pub(crate) struct ToolRouterParams<'a> { } impl ToolRouter { - pub fn from_turn_context(turn_context: &TurnContext, params: ToolRouterParams<'_>) -> Self { - build_tool_router(turn_context, params) + pub(crate) fn from_turn_context( + turn_context: &TurnContext, + params: ToolRouterParams<'_>, + tool_search_handler_cache: &ToolSearchHandlerCache, + ) -> Self { + build_tool_router(turn_context, params, tool_search_handler_cache) } pub(crate) fn from_parts(registry: ToolRegistry, model_visible_specs: Vec) -> Self { diff --git a/codex-rs/core/src/tools/router_tests.rs b/codex-rs/core/src/tools/router_tests.rs index c85c15a50..a41bab91a 100644 --- a/codex-rs/core/src/tools/router_tests.rs +++ b/codex-rs/core/src/tools/router_tests.rs @@ -120,6 +120,7 @@ async fn parallel_support_does_not_match_namespaced_local_tool_names() -> anyhow extension_tool_executors: Vec::new(), dynamic_tools: turn.dynamic_tools.as_slice(), }, + &Default::default(), ); let parallel_tool_name = ["exec_command", "shell_command"] @@ -199,6 +200,7 @@ async fn mcp_parallel_support_uses_handler_data() -> anyhow::Result<()> { extension_tool_executors: Vec::new(), dynamic_tools: turn.dynamic_tools.as_slice(), }, + &Default::default(), ); let call = ToolCall { @@ -234,6 +236,7 @@ async fn tools_without_handlers_do_not_support_parallel() -> anyhow::Result<()> extension_tool_executors: Vec::new(), dynamic_tools: turn.dynamic_tools.as_slice(), }, + &Default::default(), ); assert!(!router.tool_supports_parallel(&ToolCall { @@ -288,6 +291,7 @@ async fn specs_filter_deferred_dynamic_tools() -> anyhow::Result<()> { extension_tool_executors: Vec::new(), dynamic_tools: &dynamic_tools, }, + &Default::default(), ); assert_eq!( @@ -349,6 +353,7 @@ async fn extension_tool_executors_are_model_visible_and_dispatchable() -> anyhow extension_tool_executors: extension_tool_executors(&session), dynamic_tools: turn.dynamic_tools.as_slice(), }, + &Default::default(), ); assert!( diff --git a/codex-rs/core/src/tools/spec_plan.rs b/codex-rs/core/src/tools/spec_plan.rs index f096cddd8..872847611 100644 --- a/codex-rs/core/src/tools/spec_plan.rs +++ b/codex-rs/core/src/tools/spec_plan.rs @@ -23,7 +23,7 @@ use crate::tools::handlers::RequestUserInputHandler; use crate::tools::handlers::ShellCommandHandler; use crate::tools::handlers::ShellCommandHandlerOptions; use crate::tools::handlers::TestSyncHandler; -use crate::tools::handlers::ToolSearchHandler; +use crate::tools::handlers::ToolSearchHandlerCache; use crate::tools::handlers::ViewImageHandler; use crate::tools::handlers::WriteStdinHandler; use crate::tools::handlers::agent_jobs::ReportAgentJobResultHandler; @@ -147,6 +147,7 @@ struct CoreToolPlanContext<'a> { discoverable_tools: Option<&'a [DiscoverableTool]>, extension_tool_executors: &'a [Arc>], dynamic_tools: &'a [DynamicToolSpec], + tool_search_handler_cache: &'a ToolSearchHandlerCache, default_agent_type_description: &'a str, wait_agent_timeouts: WaitAgentTimeoutOptions, } @@ -155,8 +156,10 @@ struct CoreToolPlanContext<'a> { pub(crate) fn build_tool_router( turn_context: &TurnContext, params: ToolRouterParams<'_>, + tool_search_handler_cache: &ToolSearchHandlerCache, ) -> ToolRouter { - let (model_visible_specs, registry) = build_tool_specs_and_registry(turn_context, params); + let (model_visible_specs, registry) = + build_tool_specs_and_registry(turn_context, params, tool_search_handler_cache); ToolRouter::from_parts(registry, model_visible_specs) } @@ -164,6 +167,7 @@ pub(crate) fn build_tool_router( fn build_tool_specs_and_registry( turn_context: &TurnContext, params: ToolRouterParams<'_>, + tool_search_handler_cache: &ToolSearchHandlerCache, ) -> (Vec, ToolRegistry) { let ToolRouterParams { mcp_tools, @@ -181,6 +185,7 @@ fn build_tool_specs_and_registry( discoverable_tools: discoverable_tools.as_deref(), extension_tool_executors: &extension_tool_executors, dynamic_tools, + tool_search_handler_cache, default_agent_type_description: &default_agent_type_description, wait_agent_timeouts: wait_agent_timeout_options(turn_context), }; @@ -875,7 +880,8 @@ fn append_tool_search_executor( return; } - planned_tools.add(ToolSearchHandler::new(search_infos)); + let handler: PlannedRuntime = context.tool_search_handler_cache.get_or_build(search_infos); + planned_tools.add_arc(handler); } fn prepend_code_mode_executors( diff --git a/codex-rs/core/src/tools/spec_plan_tests.rs b/codex-rs/core/src/tools/spec_plan_tests.rs index 0cdc56785..a1bf22310 100644 --- a/codex-rs/core/src/tools/spec_plan_tests.rs +++ b/codex-rs/core/src/tools/spec_plan_tests.rs @@ -32,6 +32,7 @@ use serde_json::json; use crate::session::tests::make_session_and_context; use crate::session::turn_context::TurnContext; +use crate::tools::handlers::ToolSearchHandlerCache; use crate::tools::handlers::multi_agents_spec::MULTI_AGENT_V1_NAMESPACE; use crate::tools::router::ToolRouter; use crate::tools::router::ToolRouterParams; @@ -184,6 +185,7 @@ async fn probe_with( extension_tool_executors: inputs.extension_tool_executors, dynamic_tools: inputs.dynamic_tools.as_slice(), }, + &Default::default(), ); ToolPlanProbe::from_router(router) } @@ -765,6 +767,61 @@ async fn deferred_extension_tools_are_discoverable_with_tool_search() { assert_eq!(plan.exposure("extension_echo"), ToolExposure::Deferred); } +#[tokio::test] +async fn tool_search_cache_rebuilds_when_deferred_sources_change() { + let cache = ToolSearchHandlerCache::default(); + + let (_session, mut first_turn) = make_session_and_context().await; + first_turn.model_info.supports_search_tool = true; + let first_router = ToolRouter::from_turn_context( + &first_turn, + ToolRouterParams { + mcp_tools: None, + deferred_mcp_tools: Some(vec![mcp_tool("first", "mcp__first", "lookup")]), + discoverable_tools: None, + extension_tool_executors: Vec::new(), + dynamic_tools: &[], + }, + &cache, + ); + let first_plan = ToolPlanProbe::from_router(first_router); + + let (_session, mut second_turn) = make_session_and_context().await; + second_turn.model_info.supports_search_tool = true; + let second_router = ToolRouter::from_turn_context( + &second_turn, + ToolRouterParams { + mcp_tools: None, + deferred_mcp_tools: Some(vec![mcp_tool("second", "mcp__second", "lookup")]), + discoverable_tools: None, + extension_tool_executors: Vec::new(), + dynamic_tools: &[], + }, + &cache, + ); + let second_plan = ToolPlanProbe::from_router(second_router); + + let ToolSpec::ToolSearch { + description: first_description, + .. + } = first_plan.visible_spec("tool_search") + else { + panic!("expected first tool_search spec"); + }; + assert!(first_description.contains("- first: Tools from first.")); + assert!(!first_description.contains("- second: Tools from second.")); + + let ToolSpec::ToolSearch { + description: second_description, + .. + } = second_plan.visible_spec("tool_search") + else { + panic!("expected second tool_search spec"); + }; + assert!(second_description.contains("- second: Tools from second.")); + assert!(!second_description.contains("- first: Tools from first.")); +} + #[tokio::test] async fn invalid_mcp_tools_are_not_registered() { let plan = probe_with( diff --git a/codex-rs/tools/src/tool_search.rs b/codex-rs/tools/src/tool_search.rs index 1083d9ea9..9a1c8852d 100644 --- a/codex-rs/tools/src/tool_search.rs +++ b/codex-rs/tools/src/tool_search.rs @@ -6,13 +6,13 @@ use crate::ToolSearchSourceInfo; use crate::ToolSpec; use crate::default_namespace_description; -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub struct ToolSearchEntry { pub search_text: String, pub output: LoadableToolSpec, } -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub struct ToolSearchInfo { pub entry: ToolSearchEntry, pub source_info: Option,