diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 46bb21bf6..17b2696cb 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2921,7 +2921,6 @@ dependencies = [ name = "codex-extension-api" version = "0.0.0" dependencies = [ - "async-trait", "codex-config", "codex-context-fragments", "codex-protocol", @@ -3070,7 +3069,6 @@ dependencies = [ name = "codex-guardian" version = "0.0.0" dependencies = [ - "async-trait", "codex-core", "codex-extension-api", "codex-protocol", @@ -3255,7 +3253,6 @@ dependencies = [ name = "codex-mcp-extension" version = "0.0.0" dependencies = [ - "async-trait", "codex-config", "codex-core", "codex-core-plugins", diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index f6166260a..b83217293 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -1922,26 +1922,27 @@ async fn record_token_usage_info_notifies_extension_contributors() { records: Arc>>, } - #[async_trait::async_trait] impl codex_extension_api::TokenUsageContributor for TokenUsageRecorder { - async fn on_token_usage( - &self, - session_store: &codex_extension_api::ExtensionData, - thread_store: &codex_extension_api::ExtensionData, - turn_store: &codex_extension_api::ExtensionData, - token_usage: &TokenUsageInfo, - ) { - self.records - .lock() - .expect("token usage records lock") - .push(RecordedTokenUsage { - session_level_id: session_store.level_id().to_string(), - thread_level_id: thread_store.level_id().to_string(), - turn_level_id: turn_store.level_id().to_string(), - token_usage: token_usage.clone(), - saw_session_store: session_store.get::().is_some(), - saw_thread_store: thread_store.get::().is_some(), - }); + fn on_token_usage<'a>( + &'a self, + session_store: &'a codex_extension_api::ExtensionData, + thread_store: &'a codex_extension_api::ExtensionData, + turn_store: &'a codex_extension_api::ExtensionData, + token_usage: &'a TokenUsageInfo, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + self.records + .lock() + .expect("token usage records lock") + .push(RecordedTokenUsage { + session_level_id: session_store.level_id().to_string(), + thread_level_id: thread_store.level_id().to_string(), + turn_level_id: turn_store.level_id().to_string(), + token_usage: token_usage.clone(), + saw_session_store: session_store.get::().is_some(), + saw_thread_store: thread_store.get::().is_some(), + }); + }) } } @@ -2040,25 +2041,32 @@ async fn turn_start_lifecycle_exposes_turn_metadata_and_token_baseline() { records: Arc>>, } - #[async_trait::async_trait] impl codex_extension_api::TurnLifecycleContributor for TurnStartRecorder { - async fn on_turn_start(&self, input: codex_extension_api::TurnStartInput<'_>) { - self.records - .lock() - .expect("turn start records lock") - .push(RecordedTurnStart { - session_level_id: input.session_store.level_id().to_string(), - thread_level_id: input.thread_store.level_id().to_string(), - turn_level_id: input.turn_store.level_id().to_string(), - turn_id: input.turn_id.to_string(), - collaboration_mode: input.collaboration_mode.clone(), - token_usage_at_turn_start: input.token_usage_at_turn_start.clone(), - saw_session_store: input - .session_store - .get::() - .is_some(), - saw_thread_store: input.thread_store.get::().is_some(), - }); + fn on_turn_start<'a>( + &'a self, + input: codex_extension_api::TurnStartInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + self.records + .lock() + .expect("turn start records lock") + .push(RecordedTurnStart { + session_level_id: input.session_store.level_id().to_string(), + thread_level_id: input.thread_store.level_id().to_string(), + turn_level_id: input.turn_store.level_id().to_string(), + turn_id: input.turn_id.to_string(), + collaboration_mode: input.collaboration_mode.clone(), + token_usage_at_turn_start: input.token_usage_at_turn_start.clone(), + saw_session_store: input + .session_store + .get::() + .is_some(), + saw_thread_store: input + .thread_store + .get::() + .is_some(), + }); + }) } } @@ -2138,24 +2146,31 @@ async fn turn_error_lifecycle_exposes_error_and_stores() { records: Arc>>, } - #[async_trait::async_trait] impl codex_extension_api::TurnLifecycleContributor for TurnErrorRecorder { - async fn on_turn_error(&self, input: codex_extension_api::TurnErrorInput<'_>) { - self.records - .lock() - .expect("turn error records lock") - .push(RecordedTurnError { - session_level_id: input.session_store.level_id().to_string(), - thread_level_id: input.thread_store.level_id().to_string(), - turn_level_id: input.turn_store.level_id().to_string(), - turn_id: input.turn_id.to_string(), - error: input.error, - saw_session_store: input - .session_store - .get::() - .is_some(), - saw_thread_store: input.thread_store.get::().is_some(), - }); + fn on_turn_error<'a>( + &'a self, + input: codex_extension_api::TurnErrorInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + self.records + .lock() + .expect("turn error records lock") + .push(RecordedTurnError { + session_level_id: input.session_store.level_id().to_string(), + thread_level_id: input.thread_store.level_id().to_string(), + turn_level_id: input.turn_store.level_id().to_string(), + turn_id: input.turn_id.to_string(), + error: input.error, + saw_session_store: input + .session_store + .get::() + .is_some(), + saw_thread_store: input + .thread_store + .get::() + .is_some(), + }); + }) } } @@ -6411,16 +6426,20 @@ async fn submission_loop_channel_close_emits_thread_stop_lifecycle() { expected_thread_id: ThreadId, } - #[async_trait::async_trait] impl codex_extension_api::ThreadLifecycleContributor for ThreadStopRecorder { - async fn on_thread_stop(&self, input: codex_extension_api::ThreadStopInput<'_>) { - assert_eq!( - self.expected_thread_id.to_string(), - input.thread_store.level_id() - ); - assert!(input.session_store.get::().is_some()); - assert!(input.thread_store.get::().is_some()); - self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + fn on_thread_stop<'a>( + &'a self, + input: codex_extension_api::ThreadStopInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + assert_eq!( + self.expected_thread_id.to_string(), + input.thread_store.level_id() + ); + assert!(input.session_store.get::().is_some()); + assert!(input.thread_store.get::().is_some()); + self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + }) } } @@ -6457,33 +6476,41 @@ async fn submission_loop_channel_close_aborts_active_turn_before_thread_stop_lif expected_turn_id: String, } - #[async_trait::async_trait] impl codex_extension_api::ThreadLifecycleContributor for LifecycleRecorder { - async fn on_thread_stop(&self, input: codex_extension_api::ThreadStopInput<'_>) { - assert_eq!( - self.expected_thread_id.to_string(), - input.thread_store.level_id() - ); - self.calls - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .push("thread_stop"); + fn on_thread_stop<'a>( + &'a self, + input: codex_extension_api::ThreadStopInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + assert_eq!( + self.expected_thread_id.to_string(), + input.thread_store.level_id() + ); + self.calls + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push("thread_stop"); + }) } } - #[async_trait::async_trait] impl codex_extension_api::TurnLifecycleContributor for LifecycleRecorder { - async fn on_turn_abort(&self, input: codex_extension_api::TurnAbortInput<'_>) { - assert_eq!( - self.expected_thread_id.to_string(), - input.thread_store.level_id() - ); - assert_eq!(self.expected_turn_id, input.turn_store.level_id()); - assert_eq!(TurnAbortReason::Interrupted, input.reason); - self.calls - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .push("turn_abort"); + fn on_turn_abort<'a>( + &'a self, + input: codex_extension_api::TurnAbortInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + assert_eq!( + self.expected_thread_id.to_string(), + input.thread_store.level_id() + ); + assert_eq!(self.expected_turn_id, input.turn_store.level_id()); + assert_eq!(TurnAbortReason::Interrupted, input.reason); + self.calls + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push("turn_abort"); + }) } } @@ -8698,15 +8725,19 @@ async fn task_finish_emits_thread_idle_lifecycle_after_active_turn_clears() { expected_thread_id: ThreadId, } - #[async_trait::async_trait] impl codex_extension_api::ThreadLifecycleContributor for ThreadIdleRecorder { - async fn on_thread_idle(&self, input: codex_extension_api::ThreadIdleInput<'_>) { - assert_eq!( - self.expected_thread_id.to_string(), - input.thread_store.level_id() - ); - self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - self.idle_tx.send(()).await.expect("idle receiver open"); + fn on_thread_idle<'a>( + &'a self, + input: codex_extension_api::ThreadIdleInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + assert_eq!( + self.expected_thread_id.to_string(), + input.thread_store.level_id() + ); + self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.idle_tx.send(()).await.expect("idle receiver open"); + }) } } @@ -8740,10 +8771,14 @@ async fn thread_idle_lifecycle_waits_for_trigger_turn_mailbox_work() { calls: Arc, } - #[async_trait::async_trait] impl codex_extension_api::ThreadLifecycleContributor for ThreadIdleRecorder { - async fn on_thread_idle(&self, _input: codex_extension_api::ThreadIdleInput<'_>) { - self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + fn on_thread_idle<'a>( + &'a self, + _input: codex_extension_api::ThreadIdleInput<'a>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + }) } } diff --git a/codex-rs/core/src/session/turn_tests.rs b/codex-rs/core/src/session/turn_tests.rs index 81a43ff5a..3511c03bd 100644 --- a/codex-rs/core/src/session/turn_tests.rs +++ b/codex-rs/core/src/session/turn_tests.rs @@ -7,20 +7,21 @@ use std::sync::Arc; struct RewriteAgentMessageContributor; -#[async_trait::async_trait] impl TurnItemContributor for RewriteAgentMessageContributor { - async fn contribute( - &self, - _thread_store: &ExtensionData, - _turn_store: &ExtensionData, - item: &mut TurnItem, - ) -> Result<(), String> { - if let TurnItem::AgentMessage(agent_message) = item { - agent_message.content = vec![AgentMessageContent::Text { - text: "plan contributed assistant text".to_string(), - }]; - } - Ok(()) + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> codex_extension_api::ExtensionFuture<'a, Result<(), String>> { + Box::pin(async move { + if let TurnItem::AgentMessage(agent_message) = item { + agent_message.content = vec![AgentMessageContent::Text { + text: "plan contributed assistant text".to_string(), + }]; + } + Ok(()) + }) } } diff --git a/codex-rs/core/src/stream_events_utils_tests.rs b/codex-rs/core/src/stream_events_utils_tests.rs index 9c4be9457..c151a46fb 100644 --- a/codex-rs/core/src/stream_events_utils_tests.rs +++ b/codex-rs/core/src/stream_events_utils_tests.rs @@ -167,41 +167,43 @@ struct TestTurnItemContributor; #[derive(Debug)] struct TurnItemContributorRan; -#[async_trait::async_trait] impl TurnItemContributor for TestTurnItemContributor { - async fn contribute( - &self, - _thread_store: &ExtensionData, - turn_store: &ExtensionData, - item: &mut TurnItem, - ) -> Result<(), String> { - turn_store.insert(TurnItemContributorRan); - if let TurnItem::AgentMessage(agent_message) = item { - agent_message.memory_citation = Some(MemoryCitation { - entries: Vec::new(), - rollout_ids: Vec::new(), - }); - } - Ok(()) + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> codex_extension_api::ExtensionFuture<'a, Result<(), String>> { + Box::pin(async move { + turn_store.insert(TurnItemContributorRan); + if let TurnItem::AgentMessage(agent_message) = item { + agent_message.memory_citation = Some(MemoryCitation { + entries: Vec::new(), + rollout_ids: Vec::new(), + }); + } + Ok(()) + }) } } struct RewriteAgentMessageContributor; -#[async_trait::async_trait] impl TurnItemContributor for RewriteAgentMessageContributor { - async fn contribute( - &self, - _thread_store: &ExtensionData, - _turn_store: &ExtensionData, - item: &mut TurnItem, - ) -> Result<(), String> { - if let TurnItem::AgentMessage(agent_message) = item { - agent_message.content = vec![AgentMessageContent::Text { - text: "contributed assistant text".to_string(), - }]; - } - Ok(()) + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> codex_extension_api::ExtensionFuture<'a, Result<(), String>> { + Box::pin(async move { + if let TurnItem::AgentMessage(agent_message) = item { + agent_message.content = vec![AgentMessageContent::Text { + text: "contributed assistant text".to_string(), + }]; + } + Ok(()) + }) } } diff --git a/codex-rs/core/src/thread_manager_tests.rs b/codex-rs/core/src/thread_manager_tests.rs index bb0c0af31..48ca2598a 100644 --- a/codex-rs/core/src/thread_manager_tests.rs +++ b/codex-rs/core/src/thread_manager_tests.rs @@ -394,20 +394,24 @@ async fn start_thread_seeds_extension_data_before_lifecycle_contributors_run() { observed: Arc>>, } - #[async_trait::async_trait] impl codex_extension_api::ThreadLifecycleContributor for InitialDataRecorder { - async fn on_thread_start(&self, input: codex_extension_api::ThreadStartInput<'_, Config>) { - let marker = input - .thread_store - .get::() - .expect("initial extension data should be available"); - *self - .observed - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(( - input.thread_store.level_id().to_string(), - marker.0.to_string(), - )); + fn on_thread_start<'a>( + &'a self, + input: codex_extension_api::ThreadStartInput<'a, Config>, + ) -> codex_extension_api::ExtensionFuture<'a, ()> { + Box::pin(async move { + let marker = input + .thread_store + .get::() + .expect("initial extension data should be available"); + *self + .observed + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(( + input.thread_store.level_id().to_string(), + marker.0.to_string(), + )); + }) } } diff --git a/codex-rs/core/src/tools/handlers/extension_tools.rs b/codex-rs/core/src/tools/handlers/extension_tools.rs index ecbf772e9..5e177156c 100644 --- a/codex-rs/core/src/tools/handlers/extension_tools.rs +++ b/codex-rs/core/src/tools/handlers/extension_tools.rs @@ -388,16 +388,17 @@ mod tests { struct RecordExtensionTurnItemContributor; - #[async_trait::async_trait] impl TurnItemContributor for RecordExtensionTurnItemContributor { - async fn contribute( - &self, - _thread_store: &ExtensionData, - turn_store: &ExtensionData, - _item: &mut TurnItem, - ) -> Result<(), String> { - turn_store.insert(ExtensionTurnItemContributorRan); - Ok(()) + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + _item: &'a mut TurnItem, + ) -> codex_extension_api::ExtensionFuture<'a, Result<(), String>> { + Box::pin(async move { + turn_store.insert(ExtensionTurnItemContributorRan); + Ok(()) + }) } } diff --git a/codex-rs/ext/extension-api/Cargo.toml b/codex-rs/ext/extension-api/Cargo.toml index 2e155b430..25ed8d0d7 100644 --- a/codex-rs/ext/extension-api/Cargo.toml +++ b/codex-rs/ext/extension-api/Cargo.toml @@ -14,7 +14,6 @@ doctest = false workspace = true [dependencies] -async-trait = { workspace = true } codex-config = { workspace = true } codex-context-fragments = { workspace = true } codex-protocol = { workspace = true } diff --git a/codex-rs/ext/extension-api/src/contributors.rs b/codex-rs/ext/extension-api/src/contributors.rs index 040329437..e93088b5d 100644 --- a/codex-rs/ext/extension-api/src/contributors.rs +++ b/codex-rs/ext/extension-api/src/contributors.rs @@ -1,4 +1,5 @@ use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use codex_context_fragments::ContextualUserFragment; @@ -36,6 +37,9 @@ pub use turn_lifecycle::TurnErrorInput; pub use turn_lifecycle::TurnStartInput; pub use turn_lifecycle::TurnStopInput; +/// Boxed, sendable future returned by asynchronous extension contributors. +pub type ExtensionFuture<'a, T> = Pin + Send + 'a>>; + /// Extension contribution that resolves runtime MCP servers from host config. /// /// Contributors run in registration order. Later contributions for the same @@ -43,9 +47,8 @@ pub use turn_lifecycle::TurnStopInput; /// own and must apply any source-specific policy before returning a server. /// Plugin-owned servers and their provenance continue to be resolved by the /// plugin manager until that ownership moves into an extension explicitly. -#[async_trait::async_trait] pub trait McpServerContributor: Send + Sync { - async fn contribute(&self, config: &C) -> Vec; + fn contribute<'a>(&'a self, config: &'a C) -> ExtensionFuture<'a, Vec>; } /// Extension contribution that adds prompt fragments during prompt assembly. @@ -54,7 +57,7 @@ pub trait ContextContributor: Send + Sync { &'a self, session_store: &'a ExtensionData, thread_store: &'a ExtensionData, - ) -> std::pin::Pin> + Send + 'a>>; + ) -> ExtensionFuture<'a, Vec>; } /// Contributor for host-owned thread lifecycle gates. @@ -62,24 +65,43 @@ pub trait ContextContributor: Send + Sync { /// Implementations should use these callbacks to seed, rehydrate, or flush /// extension-private thread state. Heavy dependencies belong on the extension /// value created by the host, not in these inputs. -#[async_trait::async_trait] pub trait ThreadLifecycleContributor: Send + Sync { /// Called after thread-scoped extension stores are created, before later /// contributors can read from them. - async fn on_thread_start(&self, _input: ThreadStartInput<'_, C>) {} + fn on_thread_start<'a>(&'a self, input: ThreadStartInput<'a, C>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } /// Called after the host constructs a runtime from persisted history. - async fn on_thread_resume(&self, _input: ThreadResumeInput<'_>) {} + fn on_thread_resume<'a>(&'a self, input: ThreadResumeInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } /// Called after the host has drained immediately pending thread work. /// /// Implementations may use host capabilities captured by the extension to /// submit follow-up input. The host remains responsible for deciding /// whether that input starts a turn, is queued, or is ignored. - async fn on_thread_idle(&self, _input: ThreadIdleInput<'_>) {} + fn on_thread_idle<'a>(&'a self, input: ThreadIdleInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } /// Called before the host drops the thread runtime and thread-scoped store. - async fn on_thread_stop(&self, _input: ThreadStopInput<'_>) {} + fn on_thread_stop<'a>(&'a self, input: ThreadStopInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } } /// Contributor for host-owned turn lifecycle gates. @@ -87,20 +109,39 @@ pub trait ThreadLifecycleContributor: Send + Sync { /// Implementations should use these callbacks to seed, observe, or clear /// extension-private turn state. The host exposes stable identifiers and /// extension stores instead of core runtime objects. -#[async_trait::async_trait] pub trait TurnLifecycleContributor: Send + Sync { /// Called after turn-scoped extension stores are created, before the task /// for the turn starts running. - async fn on_turn_start(&self, _input: TurnStartInput<'_>) {} + fn on_turn_start<'a>(&'a self, input: TurnStartInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } /// Called before the host drops the completed turn runtime and turn store. - async fn on_turn_stop(&self, _input: TurnStopInput<'_>) {} + fn on_turn_stop<'a>(&'a self, input: TurnStopInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } /// Called after the host aborts a running turn. - async fn on_turn_abort(&self, _input: TurnAbortInput<'_>) {} + fn on_turn_abort<'a>(&'a self, input: TurnAbortInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } /// Called when the host observes an error for a running turn. - async fn on_turn_error(&self, _input: TurnErrorInput<'_>) {} + fn on_turn_error<'a>(&'a self, input: TurnErrorInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _input = input; + }) + } } /// Extension contribution that can add turn-local model input. @@ -109,16 +150,15 @@ pub trait TurnLifecycleContributor: Send + Sync { /// must preserve authority boundaries for external resources. Expensive or /// host-specific dependencies belong on the extension value installed by the /// host, not in this input. -#[async_trait::async_trait] pub trait TurnInputContributor: Send + Sync { /// Returns additional contextual fragments for one submitted turn. - async fn contribute( - &self, + fn contribute<'a>( + &'a self, input: TurnInputContext, - session_store: &ExtensionData, - thread_store: &ExtensionData, - turn_store: &ExtensionData, - ) -> Vec>; + session_store: &'a ExtensionData, + thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + ) -> ExtensionFuture<'a, Vec>>; } /// Contributor for host-owned configuration changes. @@ -142,16 +182,19 @@ pub trait ConfigContributor: Send + Sync { /// Implementations should keep this callback cheap. The host calls it after /// updating cached token usage and before emitting the corresponding client /// token-count notification. -#[async_trait::async_trait] pub trait TokenUsageContributor: Send + Sync { /// Called each time the host records token usage from a model response. - async fn on_token_usage( - &self, - _session_store: &ExtensionData, - _thread_store: &ExtensionData, - _turn_store: &ExtensionData, - _token_usage: &TokenUsageInfo, - ) { + fn on_token_usage<'a>( + &'a self, + _session_store: &'a ExtensionData, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + _token_usage: &'a TokenUsageInfo, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let _self = self; + let _inputs = (_session_store, _thread_store, _turn_store, _token_usage); + }) } } @@ -183,14 +226,13 @@ pub trait ToolLifecycleContributor: Send + Sync { } /// Extension contribution that can claim rendered approval-review prompts. -#[async_trait::async_trait] pub trait ApprovalReviewContributor: Send + Sync { - async fn contribute( - &self, - session_store: &ExtensionData, - thread_store: &ExtensionData, - prompt: &str, - ) -> Option; + fn contribute<'a>( + &'a self, + session_store: &'a ExtensionData, + thread_store: &'a ExtensionData, + prompt: &'a str, + ) -> ExtensionFuture<'a, Option>; } /// Ordered post-processing contribution for one parsed turn item. @@ -198,12 +240,11 @@ pub trait ApprovalReviewContributor: Send + Sync { /// Implementations may mutate the item before it is emitted and may use the /// explicitly exposed thread- and turn-lifetime stores when they need durable /// extension-private state. -#[async_trait::async_trait] pub trait TurnItemContributor: Send + Sync { - async fn contribute( - &self, - thread_store: &ExtensionData, - turn_store: &ExtensionData, - item: &mut TurnItem, - ) -> Result<(), String>; + fn contribute<'a>( + &'a self, + thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + item: &'a mut TurnItem, + ) -> ExtensionFuture<'a, Result<(), String>>; } diff --git a/codex-rs/ext/extension-api/src/lib.rs b/codex-rs/ext/extension-api/src/lib.rs index 8b269f31d..dc3495a71 100644 --- a/codex-rs/ext/extension-api/src/lib.rs +++ b/codex-rs/ext/extension-api/src/lib.rs @@ -31,6 +31,7 @@ pub use codex_tools::parse_tool_input_schema_without_compaction; pub use contributors::ApprovalReviewContributor; pub use contributors::ConfigContributor; pub use contributors::ContextContributor; +pub use contributors::ExtensionFuture; pub use contributors::McpServerContribution; pub use contributors::McpServerContributor; pub use contributors::PromptFragment; diff --git a/codex-rs/ext/extension-api/tests/registry.rs b/codex-rs/ext/extension-api/tests/registry.rs index d26f1010c..42cb18af6 100644 --- a/codex-rs/ext/extension-api/tests/registry.rs +++ b/codex-rs/ext/extension-api/tests/registry.rs @@ -1,5 +1,3 @@ -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; use std::sync::Mutex; @@ -9,6 +7,7 @@ use codex_extension_api::ContextContributor; use codex_extension_api::ContextualUserFragment; use codex_extension_api::ExtensionData; use codex_extension_api::ExtensionEventSink; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::PromptFragment; use codex_extension_api::ThreadLifecycleContributor; @@ -37,32 +36,32 @@ impl ContextContributor for AllContributors { &'a self, _session_store: &'a ExtensionData, _thread_store: &'a ExtensionData, - ) -> Pin> + Send + 'a>> { + ) -> ExtensionFuture<'a, Vec> { Box::pin(std::future::ready(Vec::new())) } } -#[async_trait::async_trait] impl ThreadLifecycleContributor<()> for AllContributors {} -#[async_trait::async_trait] impl TurnLifecycleContributor for AllContributors {} impl ConfigContributor<()> for AllContributors {} -#[async_trait::async_trait] impl TokenUsageContributor for AllContributors {} -#[async_trait::async_trait] impl TurnInputContributor for AllContributors { - async fn contribute( - &self, - _input: TurnInputContext, - _session_store: &ExtensionData, - _thread_store: &ExtensionData, - _turn_store: &ExtensionData, - ) -> Vec> { - Vec::new() + fn contribute<'a>( + &'a self, + input: TurnInputContext, + _session_store: &'a ExtensionData, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + ) -> ExtensionFuture<'a, Vec>> { + Box::pin(async move { + let _self = self; + let _input = input; + Vec::new() + }) } } @@ -78,27 +77,31 @@ impl ToolContributor for AllContributors { impl ToolLifecycleContributor for AllContributors {} -#[async_trait::async_trait] impl TurnItemContributor for AllContributors { - async fn contribute( - &self, - _thread_store: &ExtensionData, - _turn_store: &ExtensionData, - _item: &mut TurnItem, - ) -> Result<(), String> { - Ok(()) + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + _item: &'a mut TurnItem, + ) -> ExtensionFuture<'a, Result<(), String>> { + Box::pin(async move { + let _self = self; + Ok(()) + }) } } -#[async_trait::async_trait] impl ApprovalReviewContributor for AllContributors { - async fn contribute( - &self, - _session_store: &ExtensionData, - _thread_store: &ExtensionData, - _prompt: &str, - ) -> Option { - Some(ReviewDecision::ApprovedForSession) + fn contribute<'a>( + &'a self, + _session_store: &'a ExtensionData, + _thread_store: &'a ExtensionData, + _prompt: &'a str, + ) -> ExtensionFuture<'a, Option> { + Box::pin(async move { + let _self = self; + Some(ReviewDecision::ApprovedForSession) + }) } } @@ -146,7 +149,7 @@ impl ContextContributor for NamedContextContributor { &'a self, _session_store: &'a ExtensionData, _thread_store: &'a ExtensionData, - ) -> Pin> + Send + 'a>> { + ) -> ExtensionFuture<'a, Vec> { Box::pin(std::future::ready(vec![PromptFragment::developer_policy( self.0, )])) @@ -158,19 +161,20 @@ struct RecordingTurnItemContributor { calls: Arc>>, } -#[async_trait::async_trait] impl TurnItemContributor for RecordingTurnItemContributor { - async fn contribute( - &self, - _thread_store: &ExtensionData, - _turn_store: &ExtensionData, - _item: &mut TurnItem, - ) -> Result<(), String> { - self.calls - .lock() - .unwrap_or_else(|error| panic!("turn item calls lock poisoned: {error}")) - .push(self.name); - Ok(()) + fn contribute<'a>( + &'a self, + _thread_store: &'a ExtensionData, + _turn_store: &'a ExtensionData, + _item: &'a mut TurnItem, + ) -> ExtensionFuture<'a, Result<(), String>> { + Box::pin(async move { + self.calls + .lock() + .unwrap_or_else(|error| panic!("turn item calls lock poisoned: {error}")) + .push(self.name); + Ok(()) + }) } } @@ -236,24 +240,25 @@ struct RecordingApprovalContributor { calls: Arc>>, } -#[async_trait::async_trait] impl ApprovalReviewContributor for RecordingApprovalContributor { - async fn contribute( - &self, - session_store: &ExtensionData, - thread_store: &ExtensionData, - prompt: &str, - ) -> Option { - self.calls - .lock() - .unwrap_or_else(|error| panic!("approval calls lock poisoned: {error}")) - .push(ApprovalCall { - contributor: self.name, - session_id: session_store.level_id().to_string(), - thread_id: thread_store.level_id().to_string(), - prompt: prompt.to_string(), - }); - self.decision.clone() + fn contribute<'a>( + &'a self, + session_store: &'a ExtensionData, + thread_store: &'a ExtensionData, + prompt: &'a str, + ) -> ExtensionFuture<'a, Option> { + Box::pin(async move { + self.calls + .lock() + .unwrap_or_else(|error| panic!("approval calls lock poisoned: {error}")) + .push(ApprovalCall { + contributor: self.name, + session_id: session_store.level_id().to_string(), + thread_id: thread_store.level_id().to_string(), + prompt: prompt.to_string(), + }); + self.decision.clone() + }) } } diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index 0696c2b86..8e5b20e7f 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -1,12 +1,12 @@ use std::sync::Arc; use std::sync::Weak; -use async_trait::async_trait; use codex_analytics::AnalyticsEventsClient; use codex_core::ThreadManager; use codex_extension_api::ConfigContributor; use codex_extension_api::ExtensionData; use codex_extension_api::ExtensionEventSink; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::ThreadIdleInput; use codex_extension_api::ThreadLifecycleContributor; @@ -95,76 +95,83 @@ impl GoalExtension { } } -#[async_trait] impl ThreadLifecycleContributor for GoalExtension where C: Send + Sync + 'static, { - async fn on_thread_start(&self, input: ThreadStartInput<'_, C>) { - let enabled = (self.goals_enabled)(input.config); - let tools_available_for_thread = input.persistent_thread_state_available - && !matches!( - input.session_source, - SessionSource::SubAgent(SubAgentSource::Review) - ); - input - .thread_store - .insert(GoalExtensionConfig::from_enabled(enabled)); - let accounting_state = input - .thread_store - .get_or_init::(GoalAccountingState::default); - let Ok(thread_id) = ThreadId::from_string(input.thread_store.level_id()) else { - return; - }; - let runtime = input.thread_store.get_or_init::(|| { - GoalRuntimeHandle::new( - thread_id, - Arc::clone(&self.state_dbs), - self.event_emitter.clone(), - self.metrics.clone(), - self.thread_manager.clone(), - accounting_state, - GoalRuntimeConfig { - analytics: self.analytics.clone(), - enabled, - tools_available_for_thread, - }, - ) - }); - runtime.set_enabled(enabled); - self.goal_service.register_runtime(&runtime); + fn on_thread_start<'a>(&'a self, input: ThreadStartInput<'a, C>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let enabled = (self.goals_enabled)(input.config); + let tools_available_for_thread = input.persistent_thread_state_available + && !matches!( + input.session_source, + SessionSource::SubAgent(SubAgentSource::Review) + ); + input + .thread_store + .insert(GoalExtensionConfig::from_enabled(enabled)); + let accounting_state = input + .thread_store + .get_or_init::(GoalAccountingState::default); + let Ok(thread_id) = ThreadId::from_string(input.thread_store.level_id()) else { + return; + }; + let runtime = input.thread_store.get_or_init::(|| { + GoalRuntimeHandle::new( + thread_id, + Arc::clone(&self.state_dbs), + self.event_emitter.clone(), + self.metrics.clone(), + self.thread_manager.clone(), + accounting_state, + GoalRuntimeConfig { + analytics: self.analytics.clone(), + enabled, + tools_available_for_thread, + }, + ) + }); + runtime.set_enabled(enabled); + self.goal_service.register_runtime(&runtime); + }) } - async fn on_thread_resume(&self, input: ThreadResumeInput<'_>) { - let Some(runtime) = goal_runtime_handle(input.thread_store) else { - return; - }; + fn on_thread_resume<'a>(&'a self, input: ThreadResumeInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(input.thread_store) else { + return; + }; - if let Err(err) = runtime.restore_after_resume().await { - tracing::warn!( - "failed to restore goal runtime after thread resume for {}: {err}", - runtime.thread_id() - ); - } + if let Err(err) = runtime.restore_after_resume().await { + tracing::warn!( + "failed to restore goal runtime after thread resume for {}: {err}", + runtime.thread_id() + ); + } + }) } - async fn on_thread_idle(&self, input: ThreadIdleInput<'_>) { - let Some(runtime) = goal_runtime_handle(input.thread_store) else { - return; - }; + fn on_thread_idle<'a>(&'a self, input: ThreadIdleInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(input.thread_store) else { + return; + }; - if let Err(err) = runtime.continue_if_idle().await { - tracing::warn!( - "failed to continue active goal for idle thread {}: {err}", - runtime.thread_id() - ); - } + if let Err(err) = runtime.continue_if_idle().await { + tracing::warn!( + "failed to continue active goal for idle thread {}: {err}", + runtime.thread_id() + ); + } + }) } - async fn on_thread_stop(&self, input: ThreadStopInput<'_>) { - if let Some(runtime) = goal_runtime_handle(input.thread_store) { - self.goal_service.unregister_runtime(&runtime); - } + fn on_thread_stop<'a>(&'a self, input: ThreadStopInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + if let Some(runtime) = goal_runtime_handle(input.thread_store) { + self.goal_service.unregister_runtime(&runtime); + } + }) } } @@ -187,153 +194,161 @@ where } } -#[async_trait] impl TurnLifecycleContributor for GoalExtension where C: Send + Sync + 'static, { - async fn on_turn_start(&self, input: TurnStartInput<'_>) { - let Some(runtime) = goal_runtime_handle(input.thread_store) else { - return; - }; - if !runtime.is_enabled() { - return; - } + fn on_turn_start<'a>(&'a self, input: TurnStartInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(input.thread_store) else { + return; + }; + if !runtime.is_enabled() { + return; + } - let accounting = runtime.accounting_state(); - accounting.start_turn( - input.turn_id, - input.collaboration_mode.mode, - input.token_usage_at_turn_start, - ); - if matches!( - input.collaboration_mode.mode, - codex_protocol::config_types::ModeKind::Plan - ) { - accounting.clear_current_turn_goal(); - return; - } - let Ok(goal) = self - .state_dbs - .thread_goals() - .get_thread_goal(runtime.thread_id()) - .await - else { - return; - }; - if let Some(goal) = goal - && matches!( - goal.status, - codex_state::ThreadGoalStatus::Active - | codex_state::ThreadGoalStatus::BudgetLimited - ) - { - accounting.mark_turn_goal_active(input.turn_id, goal.goal_id); - } + let accounting = runtime.accounting_state(); + accounting.start_turn( + input.turn_id, + input.collaboration_mode.mode, + input.token_usage_at_turn_start, + ); + if matches!( + input.collaboration_mode.mode, + codex_protocol::config_types::ModeKind::Plan + ) { + accounting.clear_current_turn_goal(); + return; + } + let Ok(goal) = self + .state_dbs + .thread_goals() + .get_thread_goal(runtime.thread_id()) + .await + else { + return; + }; + if let Some(goal) = goal + && matches!( + goal.status, + codex_state::ThreadGoalStatus::Active + | codex_state::ThreadGoalStatus::BudgetLimited + ) + { + accounting.mark_turn_goal_active(input.turn_id, goal.goal_id); + } + }) } - async fn on_turn_stop(&self, input: TurnStopInput<'_>) { - let Some(runtime) = goal_runtime_handle(input.thread_store) else { - return; - }; - if !runtime.is_enabled() { - return; - } + fn on_turn_stop<'a>(&'a self, input: TurnStopInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(input.thread_store) else { + return; + }; + if !runtime.is_enabled() { + return; + } - let turn_id = input.turn_store.level_id(); - if let Err(err) = runtime - .account_active_goal_progress( - turn_id, - &format!("{turn_id}:turn-stop"), - codex_state::GoalAccountingMode::ActiveOnly, - BudgetLimitedGoalDisposition::ClearActive, - ) - .await - { - tracing::warn!( - "failed to account active goal progress at turn stop for {turn_id}: {err}" - ); - return; - } - runtime.accounting_state().finish_turn(turn_id); + let turn_id = input.turn_store.level_id(); + if let Err(err) = runtime + .account_active_goal_progress( + turn_id, + &format!("{turn_id}:turn-stop"), + codex_state::GoalAccountingMode::ActiveOnly, + BudgetLimitedGoalDisposition::ClearActive, + ) + .await + { + tracing::warn!( + "failed to account active goal progress at turn stop for {turn_id}: {err}" + ); + return; + } + runtime.accounting_state().finish_turn(turn_id); + }) } - async fn on_turn_abort(&self, input: TurnAbortInput<'_>) { - let Some(runtime) = goal_runtime_handle(input.thread_store) else { - return; - }; - if !runtime.is_enabled() { - return; - } + fn on_turn_abort<'a>(&'a self, input: TurnAbortInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(input.thread_store) else { + return; + }; + if !runtime.is_enabled() { + return; + } - let turn_id = input.turn_store.level_id(); - if let Err(err) = runtime - .account_active_goal_progress( - turn_id, - &format!("{turn_id}:turn-abort"), - codex_state::GoalAccountingMode::ActiveOnly, - BudgetLimitedGoalDisposition::ClearActive, - ) - .await - { - tracing::warn!( - "failed to account active goal progress after turn abort for {turn_id}: {err}" - ); - return; - } - runtime.accounting_state().finish_turn(turn_id); + let turn_id = input.turn_store.level_id(); + if let Err(err) = runtime + .account_active_goal_progress( + turn_id, + &format!("{turn_id}:turn-abort"), + codex_state::GoalAccountingMode::ActiveOnly, + BudgetLimitedGoalDisposition::ClearActive, + ) + .await + { + tracing::warn!( + "failed to account active goal progress after turn abort for {turn_id}: {err}" + ); + return; + } + runtime.accounting_state().finish_turn(turn_id); + }) } - async fn on_turn_error(&self, input: TurnErrorInput<'_>) { - let Some(runtime) = goal_runtime_handle(input.thread_store) else { - return; - }; + fn on_turn_error<'a>(&'a self, input: TurnErrorInput<'a>) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(input.thread_store) else { + return; + }; - let reason = match input.error { - CodexErrorInfo::UsageLimitExceeded => ActiveGoalStopReason::UsageLimit, - // The turn has ended because the error was non-retryable or its - // retries were exhausted. Block the goal to prevent automatic - // continuation from looping and consuming tokens, as can happen - // with compaction errors. - _ => ActiveGoalStopReason::TurnError, - }; - if let Err(err) = runtime - .stop_active_goal_for_turn(input.turn_id, reason) - .await - { - tracing::warn!( - error = ?input.error, - "failed to stop active goal after turn error: {err}" - ); - } + let reason = match input.error { + CodexErrorInfo::UsageLimitExceeded => ActiveGoalStopReason::UsageLimit, + // The turn has ended because the error was non-retryable or its + // retries were exhausted. Block the goal to prevent automatic + // continuation from looping and consuming tokens, as can happen + // with compaction errors. + _ => ActiveGoalStopReason::TurnError, + }; + if let Err(err) = runtime + .stop_active_goal_for_turn(input.turn_id, reason) + .await + { + tracing::warn!( + error = ?input.error, + "failed to stop active goal after turn error: {err}" + ); + } + }) } } -#[async_trait] impl TokenUsageContributor for GoalExtension where C: Send + Sync + 'static, { - async fn on_token_usage( - &self, - _session_store: &ExtensionData, - thread_store: &ExtensionData, - turn_store: &ExtensionData, - token_usage: &TokenUsageInfo, - ) { - let Some(runtime) = goal_runtime_handle(thread_store) else { - return; - }; - if !runtime.is_enabled() { - return; - } + fn on_token_usage<'a>( + &'a self, + _session_store: &'a ExtensionData, + thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + token_usage: &'a TokenUsageInfo, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Some(runtime) = goal_runtime_handle(thread_store) else { + return; + }; + if !runtime.is_enabled() { + return; + } - let Some(_recorded) = runtime - .accounting_state() - .record_token_usage(turn_store.level_id(), &token_usage.total_token_usage) - else { - return; - }; + let Some(_recorded) = runtime + .accounting_state() + .record_token_usage(turn_store.level_id(), &token_usage.total_token_usage) + else { + return; + }; + }) } } diff --git a/codex-rs/ext/guardian/Cargo.toml b/codex-rs/ext/guardian/Cargo.toml index 53e553ec1..513254b7c 100644 --- a/codex-rs/ext/guardian/Cargo.toml +++ b/codex-rs/ext/guardian/Cargo.toml @@ -14,7 +14,6 @@ doctest = false workspace = true [dependencies] -async-trait = { workspace = true } codex-core = { workspace = true } codex-extension-api = { workspace = true } codex-protocol = { workspace = true } diff --git a/codex-rs/ext/guardian/src/lib.rs b/codex-rs/ext/guardian/src/lib.rs index 0591887c2..a64cf4fed 100644 --- a/codex-rs/ext/guardian/src/lib.rs +++ b/codex-rs/ext/guardian/src/lib.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use codex_core::config::Config; use codex_extension_api::AgentSpawnFuture; use codex_extension_api::AgentSpawner; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::ThreadLifecycleContributor; use codex_extension_api::ThreadStartInput; @@ -47,18 +48,23 @@ impl GuardianThreadContext { } } -#[async_trait::async_trait] impl ThreadLifecycleContributor for GuardianExtension where S: Send + Sync, { - async fn on_thread_start(&self, input: ThreadStartInput<'_, Config>) { - let Ok(forked_from_thread_id) = ThreadId::from_string(input.thread_store.level_id()) else { - return; - }; - input.thread_store.insert(GuardianThreadContext { - forked_from_thread_id, - }); + fn on_thread_start<'a>( + &'a self, + input: ThreadStartInput<'a, Config>, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let Ok(forked_from_thread_id) = ThreadId::from_string(input.thread_store.level_id()) + else { + return; + }; + input.thread_store.insert(GuardianThreadContext { + forked_from_thread_id, + }); + }) } } diff --git a/codex-rs/ext/image-generation/src/extension.rs b/codex-rs/ext/image-generation/src/extension.rs index 2c68f614f..6a0016c03 100644 --- a/codex-rs/ext/image-generation/src/extension.rs +++ b/codex-rs/ext/image-generation/src/extension.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use codex_core::config::Config; use codex_extension_api::ConfigContributor; use codex_extension_api::ExtensionData; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::ThreadLifecycleContributor; use codex_extension_api::ThreadStartInput; @@ -41,13 +42,17 @@ impl From<&Config> for ImageGenerationExtensionConfig { } } -#[async_trait::async_trait] impl ThreadLifecycleContributor for ImageGenerationExtension { /// Seeds image-generation availability when a thread begins. - async fn on_thread_start(&self, input: ThreadStartInput<'_, Config>) { - input - .thread_store - .insert(ImageGenerationExtensionConfig::from(input.config)); + fn on_thread_start<'a>( + &'a self, + input: ThreadStartInput<'a, Config>, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + input + .thread_store + .insert(ImageGenerationExtensionConfig::from(input.config)); + }) } } diff --git a/codex-rs/ext/mcp/Cargo.toml b/codex-rs/ext/mcp/Cargo.toml index 66c9ad172..2d0e508b5 100644 --- a/codex-rs/ext/mcp/Cargo.toml +++ b/codex-rs/ext/mcp/Cargo.toml @@ -14,7 +14,6 @@ doctest = false workspace = true [dependencies] -async-trait = { workspace = true } codex-core = { workspace = true } codex-extension-api = { workspace = true } codex-features = { workspace = true } diff --git a/codex-rs/ext/mcp/src/lib.rs b/codex-rs/ext/mcp/src/lib.rs index 8fad361a4..8568d5815 100644 --- a/codex-rs/ext/mcp/src/lib.rs +++ b/codex-rs/ext/mcp/src/lib.rs @@ -1,4 +1,5 @@ use codex_core::config::Config; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::McpServerContribution; use codex_extension_api::McpServerContributor; @@ -7,21 +8,25 @@ use codex_mcp::hosted_plugin_runtime_mcp_server_config; struct HostedPluginRuntimeExtension; -#[async_trait::async_trait] impl McpServerContributor for HostedPluginRuntimeExtension { - async fn contribute(&self, config: &Config) -> Vec { - let name = CODEX_APPS_MCP_SERVER_NAME.to_string(); - if !config.features.enabled(codex_features::Feature::Apps) { - return vec![McpServerContribution::Remove { name }]; - } + fn contribute<'a>( + &'a self, + config: &'a Config, + ) -> ExtensionFuture<'a, Vec> { + Box::pin(async move { + let name = CODEX_APPS_MCP_SERVER_NAME.to_string(); + if !config.features.enabled(codex_features::Feature::Apps) { + return vec![McpServerContribution::Remove { name }]; + } - vec![McpServerContribution::Set { - name, - config: Box::new(hosted_plugin_runtime_mcp_server_config( - &config.chatgpt_base_url, - config.apps_mcp_product_sku.as_deref(), - )), - }] + vec![McpServerContribution::Set { + name, + config: Box::new(hosted_plugin_runtime_mcp_server_config( + &config.chatgpt_base_url, + config.apps_mcp_product_sku.as_deref(), + )), + }] + }) } } diff --git a/codex-rs/ext/mcp/tests/hosted_apps_mcp.rs b/codex-rs/ext/mcp/tests/hosted_apps_mcp.rs index bbfa07c14..bb401578f 100644 --- a/codex-rs/ext/mcp/tests/hosted_apps_mcp.rs +++ b/codex-rs/ext/mcp/tests/hosted_apps_mcp.rs @@ -150,11 +150,15 @@ fn installed_manager(config: &Config) -> McpManager { struct RemoveCodexApps; -#[async_trait::async_trait] impl McpServerContributor for RemoveCodexApps { - async fn contribute(&self, _config: &Config) -> Vec { - vec![McpServerContribution::Remove { - name: CODEX_APPS_MCP_SERVER_NAME.to_string(), - }] + fn contribute<'a>( + &'a self, + _config: &'a Config, + ) -> codex_extension_api::ExtensionFuture<'a, Vec> { + Box::pin(async move { + vec![McpServerContribution::Remove { + name: CODEX_APPS_MCP_SERVER_NAME.to_string(), + }] + }) } } diff --git a/codex-rs/ext/memories/src/extension.rs b/codex-rs/ext/memories/src/extension.rs index 0be773b4c..23c80a877 100644 --- a/codex-rs/ext/memories/src/extension.rs +++ b/codex-rs/ext/memories/src/extension.rs @@ -4,6 +4,7 @@ use codex_core::config::Config; use codex_extension_api::ConfigContributor; use codex_extension_api::ContextContributor; use codex_extension_api::ExtensionData; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::PromptFragment; use codex_extension_api::ThreadLifecycleContributor; @@ -69,12 +70,16 @@ impl ContextContributor for MemoriesExtension { } } -#[async_trait::async_trait] impl ThreadLifecycleContributor for MemoriesExtension { - async fn on_thread_start(&self, input: ThreadStartInput<'_, Config>) { - input - .thread_store - .insert(MemoriesExtensionConfig::from_config(input.config)); + fn on_thread_start<'a>( + &'a self, + input: ThreadStartInput<'a, Config>, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + input + .thread_store + .insert(MemoriesExtensionConfig::from_config(input.config)); + }) } } diff --git a/codex-rs/ext/skills/Cargo.toml b/codex-rs/ext/skills/Cargo.toml index 9ff42cfaf..c67d4b0e1 100644 --- a/codex-rs/ext/skills/Cargo.toml +++ b/codex-rs/ext/skills/Cargo.toml @@ -14,7 +14,6 @@ doctest = false workspace = true [dependencies] -async-trait = { workspace = true } codex-core = { workspace = true } codex-core-skills = { workspace = true } codex-exec-server = { workspace = true } @@ -24,5 +23,6 @@ codex-utils-absolute-path = { workspace = true } codex-utils-string = { workspace = true } [dev-dependencies] +async-trait = { workspace = true } pretty_assertions = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/codex-rs/ext/skills/src/extension.rs b/codex-rs/ext/skills/src/extension.rs index 45d7c95cb..e876c66b2 100644 --- a/codex-rs/ext/skills/src/extension.rs +++ b/codex-rs/ext/skills/src/extension.rs @@ -10,6 +10,7 @@ use codex_extension_api::ContextContributor; use codex_extension_api::ContextualUserFragment; use codex_extension_api::ExtensionData; use codex_extension_api::ExtensionEventSink; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::PromptFragment; use codex_extension_api::ThreadLifecycleContributor; @@ -44,18 +45,22 @@ struct SkillsExtension { event_sink: Arc, } -#[async_trait::async_trait] impl ThreadLifecycleContributor for SkillsExtension { - async fn on_thread_start(&self, input: ThreadStartInput<'_, Config>) { - let selected_roots = input - .thread_store - .get::>() - .map(|selected_roots| selected_roots.as_ref().clone()) - .unwrap_or_default(); - input.thread_store.insert(SkillsThreadState::new( - SkillsExtensionConfig::from_config(input.config), - selected_roots, - )); + fn on_thread_start<'a>( + &'a self, + input: ThreadStartInput<'a, Config>, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + let selected_roots = input + .thread_store + .get::>() + .map(|selected_roots| selected_roots.as_ref().clone()) + .unwrap_or_default(); + input.thread_store.insert(SkillsThreadState::new( + SkillsExtensionConfig::from_config(input.config), + selected_roots, + )); + }) } } @@ -112,112 +117,117 @@ impl ContextContributor for SkillsExtension { } } -#[async_trait::async_trait] impl TurnInputContributor for SkillsExtension { - async fn contribute( - &self, + fn contribute<'a>( + &'a self, input: TurnInputContext, - _session_store: &ExtensionData, - thread_store: &ExtensionData, - turn_store: &ExtensionData, - ) -> Vec> { - let Some(thread_state) = thread_store.get::() else { - return Vec::new(); - }; + _session_store: &'a ExtensionData, + thread_store: &'a ExtensionData, + turn_store: &'a ExtensionData, + ) -> ExtensionFuture<'a, Vec>> { + Box::pin(async move { + let Some(thread_state) = thread_store.get::() else { + return Vec::new(); + }; - let config = thread_state.config(); - let host_loaded_skills = turn_store.get::(); - let query = SkillListQuery { - turn_id: input.turn_id.clone(), - executor_roots: thread_state.selected_roots().to_vec(), - host: host_loaded_skills.clone(), - include_host_skills: true, - include_bundled_skills: config.bundled_skills_enabled, - include_remote_skills: true, - }; - let catalog = self.providers.list_for_turn(query).await; - for warning in &catalog.warnings { - self.emit_warning(&input.turn_id, warning.clone()); - } - - let selected_entries = collect_explicit_skill_mentions(&input.user_input, &catalog); - let mut fragments: Vec> = Vec::new(); - if config.include_instructions { - let mut turn_catalog = catalog.clone(); - turn_catalog - .entries - .retain(|entry| entry.authority.kind != SkillSourceKind::Executor); - if let Some(fragment) = available_skills_fragment(&turn_catalog) { - fragments.push(Box::new(fragment)); + let config = thread_state.config(); + let host_loaded_skills = turn_store.get::(); + let query = SkillListQuery { + turn_id: input.turn_id.clone(), + executor_roots: thread_state.selected_roots().to_vec(), + host: host_loaded_skills.clone(), + include_host_skills: true, + include_bundled_skills: config.bundled_skills_enabled, + include_remote_skills: true, + }; + let catalog = self.providers.list_for_turn(query).await; + for warning in &catalog.warnings { + self.emit_warning(&input.turn_id, warning.clone()); } - } - let mut warnings = catalog.warnings.clone(); - let mut main_prompts_injected = false; - let mut injected_host_skill_prompts = InjectedHostSkillPrompts::default(); - for entry in &selected_entries { - match self - .read_main_prompt(entry, host_loaded_skills.clone()) - .await - { - Ok(read_result) => { - let (contents, truncated) = - truncate_main_prompt_contents(read_result.contents.as_str()); - if truncated { - let warning = format!( - "Skill `{}` exceeded the main prompt context limit and was truncated.", - entry.name - ); + let selected_entries = collect_explicit_skill_mentions(&input.user_input, &catalog); + let mut fragments: Vec> = Vec::new(); + if config.include_instructions { + let mut turn_catalog = catalog.clone(); + turn_catalog + .entries + .retain(|entry| entry.authority.kind != SkillSourceKind::Executor); + if let Some(fragment) = available_skills_fragment(&turn_catalog) { + fragments.push(Box::new(fragment)); + } + } + + let mut warnings = catalog.warnings.clone(); + let mut main_prompts_injected = false; + let mut injected_host_skill_prompts = InjectedHostSkillPrompts::default(); + for entry in &selected_entries { + match self + .read_main_prompt(entry, host_loaded_skills.clone()) + .await + { + Ok(read_result) => { + let (contents, truncated) = + truncate_main_prompt_contents(read_result.contents.as_str()); + if truncated { + let warning = format!( + "Skill `{}` exceeded the main prompt context limit and was truncated.", + entry.name + ); + self.emit_warning(&input.turn_id, warning.clone()); + warnings.push(warning); + } + let injection = SkillInjection { + name: truncate_utf8_to_bytes(&entry.name, MAX_SKILL_NAME_BYTES).0, + path: truncate_utf8_to_bytes( + entry.rendered_path(), + MAX_SKILL_PATH_BYTES, + ) + .0, + contents, + }; + fragments.push(Box::new(SkillInstructions::from(&injection))); + main_prompts_injected = true; + if entry.authority.kind == SkillSourceKind::Host { + injected_host_skill_prompts.insert_path(entry.main_prompt.as_str()); + } + } + Err(message) => { + let warning = format!("Failed to load skill `{}`: {message}", entry.name); self.emit_warning(&input.turn_id, warning.clone()); warnings.push(warning); } - let injection = SkillInjection { - name: truncate_utf8_to_bytes(&entry.name, MAX_SKILL_NAME_BYTES).0, - path: truncate_utf8_to_bytes(entry.rendered_path(), MAX_SKILL_PATH_BYTES).0, - contents, - }; - fragments.push(Box::new(SkillInstructions::from(&injection))); - main_prompts_injected = true; - if entry.authority.kind == SkillSourceKind::Host { - injected_host_skill_prompts.insert_path(entry.main_prompt.as_str()); + } + } + + if let Some(host_loaded_skills) = &host_loaded_skills { + for entry in selected_entries + .iter() + .filter(|entry| entry.authority.kind != SkillSourceKind::Host) + { + for host_skill in host_loaded_skills + .outcome() + .skills + .iter() + .filter(|host_skill| host_skill.name == entry.name) + { + injected_host_skill_prompts + .insert_path(host_skill.path_to_skills_md.to_string_lossy()); } } - Err(message) => { - let warning = format!("Failed to load skill `{}`: {message}", entry.name); - self.emit_warning(&input.turn_id, warning.clone()); - warnings.push(warning); - } } - } - if let Some(host_loaded_skills) = &host_loaded_skills { - for entry in selected_entries - .iter() - .filter(|entry| entry.authority.kind != SkillSourceKind::Host) - { - for host_skill in host_loaded_skills - .outcome() - .skills - .iter() - .filter(|host_skill| host_skill.name == entry.name) - { - injected_host_skill_prompts - .insert_path(host_skill.path_to_skills_md.to_string_lossy()); - } + turn_store.insert(SkillsTurnState { + catalog, + selected_entries, + warnings, + main_prompts_injected, + }); + if !injected_host_skill_prompts.is_empty() { + turn_store.insert(injected_host_skill_prompts); } - } - turn_store.insert(SkillsTurnState { - catalog, - selected_entries, - warnings, - main_prompts_injected, - }); - if !injected_host_skill_prompts.is_empty() { - turn_store.insert(injected_host_skill_prompts); - } - - fragments + fragments + }) } } diff --git a/codex-rs/ext/web-search/src/extension.rs b/codex-rs/ext/web-search/src/extension.rs index d081d4fb2..688b504de 100644 --- a/codex-rs/ext/web-search/src/extension.rs +++ b/codex-rs/ext/web-search/src/extension.rs @@ -9,6 +9,7 @@ use codex_api::SearchSettings; use codex_core::config::Config; use codex_extension_api::ConfigContributor; use codex_extension_api::ExtensionData; +use codex_extension_api::ExtensionFuture; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::ThreadLifecycleContributor; use codex_extension_api::ThreadStartInput; @@ -80,12 +81,16 @@ fn search_settings(config: &Config, web_search_mode: WebSearchMode) -> SearchSet } } -#[async_trait::async_trait] impl ThreadLifecycleContributor for WebSearchExtension { - async fn on_thread_start(&self, input: ThreadStartInput<'_, Config>) { - input - .thread_store - .insert(WebSearchExtensionConfig::from(input.config)); + fn on_thread_start<'a>( + &'a self, + input: ThreadStartInput<'a, Config>, + ) -> ExtensionFuture<'a, ()> { + Box::pin(async move { + input + .thread_store + .insert(WebSearchExtensionConfig::from(input.config)); + }) } }