diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index 57f1ce908..2bc63b404 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -9153,6 +9153,89 @@ impl SessionTask for CompletingTask { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum TerminalEventKind { + TurnComplete, + TurnAborted, +} + +async fn attach_in_memory_thread_store( + session: &mut Session, +) -> Arc { + let store = Arc::new(codex_thread_store::InMemoryThreadStore::default()); + let thread_store: Arc = store.clone(); + let config = session.get_config().await; + let live_thread = LiveThread::create( + Arc::clone(&thread_store), + CreateThreadParams { + session_id: session.session_id(), + thread_id: session.thread_id, + extra_config: None, + forked_from_id: None, + parent_thread_id: None, + source: SessionSource::Exec, + thread_source: None, + originator: "test_originator".to_string(), + base_instructions: BaseInstructions::default(), + dynamic_tools: Vec::new(), + multi_agent_version: None, + initial_window_id: Uuid::now_v7().to_string(), + metadata: ThreadPersistenceMetadata { + cwd: Some(config.cwd.to_path_buf()), + model_provider: config.model_provider_id.clone(), + memory_mode: if config.memories.generate_memories { + ThreadMemoryMode::Enabled + } else { + ThreadMemoryMode::Disabled + }, + }, + }, + ) + .await + .expect("create thread persistence"); + session.services.thread_store = thread_store; + session.services.live_thread = Some(live_thread); + store +} + +async fn wait_for_flush_count( + store: &codex_thread_store::InMemoryThreadStore, + expected_flushes: usize, +) -> codex_thread_store::InMemoryThreadStoreCalls { + timeout(Duration::from_secs(2), async { + loop { + let calls = store.calls().await; + if calls.flush_thread >= expected_flushes { + return calls; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("store should observe expected flush count") +} + +async fn recv_terminal_event( + rx: &async_channel::Receiver, + expected: TerminalEventKind, +) -> Event { + timeout(Duration::from_secs(2), async { + loop { + let event = rx.recv().await.expect("event"); + match (&event.msg, expected) { + (EventMsg::TurnComplete(_), TerminalEventKind::TurnComplete) + | (EventMsg::TurnAborted(_), TerminalEventKind::TurnAborted) => return event, + (EventMsg::TurnComplete(_) | EventMsg::TurnAborted(_), _) => { + panic!("unexpected terminal event: {:?}", event.msg) + } + _ => {} + } + } + }) + .await + .expect("terminal event should be delivered") +} + #[derive(Clone, Copy)] struct NeverEndingTask { kind: TaskKind, @@ -9310,6 +9393,79 @@ async fn guardian_helper_review_interrupts_after_three_consecutive_denials() { assert_eq!(aborted.reason, TurnAbortReason::Interrupted); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn turn_complete_flushes_terminal_event_after_delivery() { + let (mut sess, tc, rx) = make_session_and_context_with_rx().await; + let store = attach_in_memory_thread_store( + Arc::get_mut(&mut sess).expect("session should be uniquely owned"), + ) + .await; + + let input = vec![TurnInput::UserInput { + content: vec![UserInput::Text { + text: "complete normally".to_string(), + text_elements: Vec::new(), + }], + client_id: None, + }]; + sess.spawn_task(Arc::clone(&tc), input, CompletingTask) + .await; + + let event = recv_terminal_event(&rx, TerminalEventKind::TurnComplete).await; + assert!(matches!(event.msg, EventMsg::TurnComplete(_))); + // Expected flushes: + // 1. Task-runner flush after the task body finishes, before TurnComplete is emitted. + // 2. Terminal-event flush after TurnComplete is appended. + let calls = wait_for_flush_count(&store, /*expected_flushes*/ 2).await; + assert_eq!(2, calls.flush_thread); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn turn_aborted_flushes_terminal_event_after_delivery() { + let (mut sess, tc, rx) = make_session_and_context_with_rx().await; + let store = attach_in_memory_thread_store( + Arc::get_mut(&mut sess).expect("session should be uniquely owned"), + ) + .await; + + let input = vec![TurnInput::UserInput { + content: vec![UserInput::Text { + text: "interrupt me".to_string(), + text_elements: Vec::new(), + }], + client_id: None, + }]; + sess.spawn_task( + Arc::clone(&tc), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + let abort_task = tokio::spawn({ + let sess = Arc::clone(&sess); + async move { + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + } + }); + + let event = recv_terminal_event(&rx, TerminalEventKind::TurnAborted).await; + match event.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + abort_task.await.expect("abort task should finish"); + // Expected flushes: + // 1. Task-runner flush after the task body observes cancellation. + // 2. Interrupted-marker flush before TurnAborted so abort observers can reread it. + // 3. Terminal-event flush after TurnAborted is appended. + let calls = wait_for_flush_count(&store, /*expected_flushes*/ 3).await; + assert_eq!(3, calls.flush_thread); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[test_log::test] async fn abort_regular_task_emits_marker_before_turn_aborted() { diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index cb532445f..c6d9d2990 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -793,10 +793,14 @@ impl Session { false } }; - if !cleared_active_turn { - return; + if cleared_active_turn { + self.emit_thread_idle_lifecycle_if_idle().await; + } + // Regular items were flushed before this terminal event was appended; buffering + // thread writers may not flush it without another explicit barrier. + if let Err(err) = self.flush_rollout().await { + warn!("failed to flush rollout after emitting terminal turn event: {err}"); } - self.emit_thread_idle_lifecycle_if_idle().await; } async fn take_active_turn(&self) -> Option { @@ -896,6 +900,11 @@ impl Session { .lock() .await .clear_turn(&task.turn_context.sub_id); + // Regular items were flushed before this terminal event was appended; buffering + // thread writers may not flush it without another explicit barrier. + if let Err(err) = self.flush_rollout().await { + warn!("failed to flush rollout after emitting terminal turn event: {err}"); + } } }