diff --git a/codex-rs/ext/goal/Cargo.toml b/codex-rs/ext/goal/Cargo.toml index 667ed5d14..611638324 100644 --- a/codex-rs/ext/goal/Cargo.toml +++ b/codex-rs/ext/goal/Cargo.toml @@ -24,6 +24,7 @@ codex-tools = { workspace = true } codex-utils-template = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tokio = { workspace = true, features = ["sync"] } tracing = { workspace = true } [dev-dependencies] diff --git a/codex-rs/ext/goal/src/accounting.rs b/codex-rs/ext/goal/src/accounting.rs index e0a84f5f8..db7766177 100644 --- a/codex-rs/ext/goal/src/accounting.rs +++ b/codex-rs/ext/goal/src/accounting.rs @@ -6,10 +6,13 @@ use std::sync::Mutex; use std::sync::PoisonError; use std::time::Duration; use std::time::Instant; +use tokio::sync::Semaphore; +use tokio::sync::SemaphorePermit; -#[derive(Debug, Default)] +#[derive(Debug)] pub(crate) struct GoalAccountingState { inner: Mutex, + progress_accounting_lock: Semaphore, } #[derive(Debug)] @@ -83,6 +86,17 @@ impl GoalAccountingState { self.inner().current_turn_id.clone() } + /// Acquires the per-thread progress-accounting permit. + /// + /// Hold the returned permit from before taking a progress snapshot until after the persistent + /// usage write has succeeded and the snapshot has been marked accounted. This serializes + /// concurrent tool-completion hooks so only one hook can charge a given token or time delta. + pub(crate) async fn progress_accounting_permit( + &self, + ) -> Result, tokio::sync::AcquireError> { + self.progress_accounting_lock.acquire().await + } + pub(crate) fn turn_is_current_active_goal(&self, turn_id: &str) -> bool { let inner = self.inner(); if inner.current_turn_id.as_deref() != Some(turn_id) { @@ -287,6 +301,15 @@ impl GoalAccountingState { } } +impl Default for GoalAccountingState { + fn default() -> Self { + Self { + inner: Mutex::new(GoalAccountingInner::default()), + progress_accounting_lock: Semaphore::new(/*permits*/ 1), + } + } +} + fn token_delta_since_last_accounting(last: &TokenUsage, current: &TokenUsage) -> i64 { let delta = TokenUsage { input_tokens: current.input_tokens.saturating_sub(last.input_tokens), diff --git a/codex-rs/ext/goal/src/runtime.rs b/codex-rs/ext/goal/src/runtime.rs index 8d79391c9..de818dabe 100644 --- a/codex-rs/ext/goal/src/runtime.rs +++ b/codex-rs/ext/goal/src/runtime.rs @@ -348,6 +348,10 @@ impl GoalRuntimeHandle { budget_limited_goal_disposition: BudgetLimitedGoalDisposition, ) -> Result, String> { let accounting = self.accounting_state(); + let _accounting_permit = accounting + .progress_accounting_permit() + .await + .map_err(|err| err.to_string())?; let Some(snapshot) = accounting.progress_snapshot(turn_id) else { return Ok(None); }; @@ -398,6 +402,10 @@ impl GoalRuntimeHandle { budget_limited_goal_disposition: BudgetLimitedGoalDisposition, ) -> Result, String> { let accounting = self.accounting_state(); + let _accounting_permit = accounting + .progress_accounting_permit() + .await + .map_err(|err| err.to_string())?; let Some(snapshot) = accounting.idle_progress_snapshot() else { return Ok(None); }; diff --git a/codex-rs/ext/goal/src/tool.rs b/codex-rs/ext/goal/src/tool.rs index 45ef3c54d..17947aa34 100644 --- a/codex-rs/ext/goal/src/tool.rs +++ b/codex-rs/ext/goal/src/tool.rs @@ -291,6 +291,15 @@ impl GoalToolExecutor { let Some(turn_id) = self.accounting_state.current_turn_id() else { return Ok(None); }; + let _accounting_permit = self + .accounting_state + .progress_accounting_permit() + .await + .map_err(|err| { + FunctionCallError::Fatal(format!( + "goal progress accounting semaphore closed: {err}" + )) + })?; let Some(snapshot) = self.accounting_state.progress_snapshot(turn_id.as_str()) else { return Ok(None); }; diff --git a/codex-rs/ext/goal/tests/goal_extension_backend.rs b/codex-rs/ext/goal/tests/goal_extension_backend.rs index 16b8cf261..3c437a088 100644 --- a/codex-rs/ext/goal/tests/goal_extension_backend.rs +++ b/codex-rs/ext/goal/tests/goal_extension_backend.rs @@ -269,6 +269,69 @@ async fn tool_finish_accounts_active_goal_progress_and_emits_event() -> anyhow:: Ok(()) } +#[tokio::test] +async fn parallel_tool_finish_accounts_active_goal_progress_once() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let harness = GoalExtensionHarness::new(runtime.clone(), thread_id).await?; + harness + .start_turn( + "turn-1", + &token_usage( + /*input_tokens*/ 100, /*cached_input_tokens*/ 0, + /*output_tokens*/ 0, /*reasoning_output_tokens*/ 0, + /*total_tokens*/ 100, + ), + ) + .await; + + let tools = harness.tools(); + let create_tool = tool_by_name(&tools, "create_goal"); + create_tool + .handle(tool_call( + "create_goal", + "call-create-goal", + json!({ "objective": "ship goal extension backend" }), + )) + .await?; + harness.sink.clear(); + + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 130, /*cached_input_tokens*/ 0, + /*output_tokens*/ 0, /*reasoning_output_tokens*/ 0, + /*total_tokens*/ 130, + ), + ) + .await; + + tokio::join!( + harness.notify_tool_finish("turn-1", "call-shell-1", "shell"), + harness.notify_tool_finish("turn-1", "call-shell-2", "shell"), + ); + + let goal = runtime + .thread_goals() + .get_thread_goal(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("goal should exist"))?; + assert_eq!(30, goal.tokens_used); + + assert_eq!( + vec![CapturedGoalEvent { + event_id: "call-shell-1".to_string(), + turn_id: Some("turn-1".to_string()), + status: ThreadGoalStatus::Active, + tokens_used: 30, + }], + harness.sink.goal_events() + ); + Ok(()) +} + #[tokio::test] async fn budget_limited_goal_keeps_accruing_until_turn_stop() -> anyhow::Result<()> { let runtime = test_runtime().await?;