Unify thread metadata updates above store (#22236)

- make ThreadStore::update_thread_metadata accept a broad range of
metadata patches
- keep ThreadStore::append_items as raw canonical history append (no
metadata side effects)
- in the local store, write these metadata updates to a combination of
SQLite and rollout JSONL files for backwards compatibility. The local store
special-cases which fields go into JSONL and which go into SQLite, confining
the awkwardness to just this implementation
- in remote stores we can simply persist the metadata directly to a
database, no special casing required.
- move the "implicit metadata updates triggered by appending rollout
items" from the RolloutRecorder (which is local-threadstore-specific) to
the LiveThread layer above the ThreadStore, inside a private helper
utility called ThreadMetadataSync. LiveThread calls
ThreadStore::append_items and ThreadStore::update_thread_metadata separately.
- Add a generic update metadata method to ThreadManager that works on
both live threads and "cold" threads
- Call that ThreadManager method from app server code, so app server
doesn't need to worry about whether the thread is live or not
This commit is contained in:
Tom
2026-05-12 17:28:15 -07:00
committed by GitHub
Unverified
parent f11ad1eacb
commit c51c65ad09
31 changed files with 2382 additions and 762 deletions
+1
View File
@@ -3687,6 +3687,7 @@ dependencies = [
"codex-protocol",
"codex-rollout",
"codex-state",
"codex-utils-path",
"pretty_assertions",
"serde",
"serde_json",
@@ -401,7 +401,6 @@ use codex_thread_store::ThreadMetadataPatch as StoreThreadMetadataPatch;
use codex_thread_store::ThreadSortKey as StoreThreadSortKey;
use codex_thread_store::ThreadStore;
use codex_thread_store::ThreadStoreError;
use codex_thread_store::UpdateThreadMetadataParams as StoreUpdateThreadMetadataParams;
use codex_utils_absolute_path::AbsolutePathBuf;
use codex_utils_pty::DEFAULT_OUTPUT_BYTES_CAP;
use std::collections::HashMap;
@@ -323,7 +323,7 @@ impl ExternalAgentConfigRequestProcessor {
.thread
.update_thread_metadata(
ThreadMetadataPatch {
name: Some(name),
name: Some(Some(name)),
..Default::default()
},
/*include_archived*/ false,
@@ -1430,17 +1430,17 @@ impl ThreadRequestProcessor {
};
let _thread_list_state_permit = self.acquire_thread_list_state_permit().await?;
self.thread_store
.update_thread_metadata(StoreUpdateThreadMetadataParams {
self.thread_manager
.update_thread_metadata(
thread_id,
patch: StoreThreadMetadataPatch {
name: Some(name.clone()),
StoreThreadMetadataPatch {
name: Some(Some(name.clone())),
..Default::default()
},
include_archived: false,
})
/*include_archived*/ false,
)
.await
.map_err(|err| thread_store_write_error("set thread name", err))?;
.map_err(|err| core_thread_write_error("set thread name", err))?;
Ok((
ThreadSetNameResponse {},
@@ -1459,33 +1459,17 @@ impl ThreadRequestProcessor {
let thread_id = ThreadId::from_string(&thread_id)
.map_err(|err| invalid_request(format!("invalid thread id: {err}")))?;
if let Ok(thread) = self.thread_manager.get_thread(thread_id).await {
if thread.config_snapshot().await.ephemeral {
return Err(invalid_request(format!(
"ephemeral thread does not support memory mode updates: {thread_id}"
)));
}
thread
.set_thread_memory_mode(mode.to_core())
.await
.map_err(|err| {
internal_error(format!("failed to set thread memory mode: {err}"))
})?;
return Ok(ThreadMemoryModeSetResponse {});
}
self.thread_store
.update_thread_metadata(StoreUpdateThreadMetadataParams {
self.thread_manager
.update_thread_metadata(
thread_id,
patch: StoreThreadMetadataPatch {
StoreThreadMetadataPatch {
memory_mode: Some(mode.to_core()),
..Default::default()
},
include_archived: false,
})
/*include_archived*/ false,
)
.await
.map_err(|err| thread_store_write_error("set thread memory mode", err))?;
.map_err(|err| core_thread_write_error("set thread memory mode", err))?;
Ok(ThreadMemoryModeSetResponse {})
}
@@ -1551,35 +1535,19 @@ impl ThreadRequestProcessor {
..Default::default()
};
let loaded_thread = self.thread_manager.get_thread(thread_uuid).await.ok();
let updated_thread = {
let _thread_list_state_permit = self.acquire_thread_list_state_permit().await?;
if let Some(loaded_thread) = loaded_thread.as_ref() {
if loaded_thread.config_snapshot().await.ephemeral {
return Err(invalid_request(format!(
"ephemeral thread does not support metadata updates: {thread_id}"
)));
}
loaded_thread
.update_thread_metadata(patch, /*include_archived*/ true)
.await
} else {
self.thread_store
.update_thread_metadata(StoreUpdateThreadMetadataParams {
thread_id: thread_uuid,
patch,
include_archived: true,
})
.await
}
.map_err(|err| thread_store_write_error("update thread metadata", err))?
self.thread_manager
.update_thread_metadata(thread_uuid, patch, /*include_archived*/ true)
.await
.map_err(|err| core_thread_write_error("update thread metadata", err))?
};
let (mut thread, _) = thread_from_stored_thread(
updated_thread,
self.config.model_provider_id.as_str(),
&self.config.cwd,
);
if let Some(loaded_thread) = loaded_thread.as_ref() {
if let Ok(loaded_thread) = self.thread_manager.get_thread(thread_uuid).await {
thread.session_id = loaded_thread.session_configured().session_id.to_string();
}
self.attach_thread_name(thread_uuid, &mut thread).await;
@@ -3707,15 +3675,13 @@ fn conversation_summary_rollout_path_read_error(
}
}
fn thread_store_write_error(operation: &str, err: ThreadStoreError) -> JSONRPCErrorError {
fn core_thread_write_error(operation: &str, err: CodexErr) -> JSONRPCErrorError {
match err {
ThreadStoreError::ThreadNotFound { thread_id } => {
CodexErr::ThreadNotFound(thread_id) => {
invalid_request(format!("thread not found: {thread_id}"))
}
ThreadStoreError::InvalidRequest { message } => invalid_request(message),
ThreadStoreError::Unsupported { operation } => {
unsupported_thread_store_operation(operation)
}
CodexErr::InvalidRequest(message) => invalid_request(message),
CodexErr::UnsupportedOperation(message) => method_not_found(message),
err => internal_error(format!("failed to {operation}: {err}")),
}
}
@@ -1304,7 +1304,7 @@ async fn seed_pathless_store_thread(
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id,
patch: ThreadMetadataPatch {
name: Some("named pathless thread".to_string()),
name: Some(Some("named pathless thread".to_string())),
..Default::default()
},
include_archived: true,
@@ -226,7 +226,7 @@ async fn thread_unarchive_preserves_pathless_store_metadata() -> Result<()> {
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id,
patch: ThreadMetadataPatch {
name: Some("named pathless thread".to_string()),
name: Some(Some("named pathless thread".to_string())),
..Default::default()
},
include_archived: true,
+16 -3
View File
@@ -2846,12 +2846,21 @@ impl Session {
&self,
turn_context: &TurnContext,
token_usage: Option<&TokenUsage>,
) {
self.record_token_usage_info(turn_context, token_usage)
.await;
self.send_token_count_event(turn_context).await;
}
pub(crate) async fn record_token_usage_info(
&self,
turn_context: &TurnContext,
token_usage: Option<&TokenUsage>,
) {
if let Some(token_usage) = token_usage {
let mut state = self.state.lock().await;
state.update_token_info_from_usage(token_usage, turn_context.model_context_window());
}
self.send_token_count_event(turn_context).await;
}
pub(crate) async fn recompute_token_usage(&self, turn_context: &TurnContext) {
@@ -2892,11 +2901,15 @@ impl Session {
turn_context: &TurnContext,
new_rate_limits: RateLimitSnapshot,
) {
self.record_rate_limits_info(new_rate_limits).await;
self.send_token_count_event(turn_context).await;
}
pub(crate) async fn record_rate_limits_info(&self, new_rate_limits: RateLimitSnapshot) {
{
let mut state = self.state.lock().await;
state.set_rate_limits(new_rate_limits);
}
self.send_token_count_event(turn_context).await;
}
pub(crate) async fn mcp_dependency_prompted(&self) -> HashSet<String> {
@@ -2927,7 +2940,7 @@ impl Session {
state.set_server_reasoning_included(included);
}
async fn send_token_count_event(&self, turn_context: &TurnContext) {
pub(crate) async fn send_token_count_event(&self, turn_context: &TurnContext) {
let (info, rate_limits) = {
let state = self.state.lock().await;
state.token_info_and_rate_limits()
+13 -2
View File
@@ -1872,6 +1872,7 @@ async fn try_run_sampling_request(
Box<dyn ToolArgumentDiffConsumer>,
)> = None;
let mut should_emit_turn_diff = false;
let mut should_emit_token_count = false;
let reasoning_effort = turn_context.effective_reasoning_effort_for_tracing();
let plan_mode = turn_context.collaboration_mode.mode == ModeKind::Plan;
let mut assistant_message_stream_parsers = AssistantMessageStreamParsers::new(plan_mode);
@@ -2098,7 +2099,8 @@ async fn try_run_sampling_request(
ResponseEvent::RateLimits(snapshot) => {
// Update internal state with latest rate limits, but defer sending until
// token usage is available to avoid duplicate TokenCount events.
sess.update_rate_limits(&turn_context, snapshot).await;
sess.record_rate_limits_info(snapshot).await;
should_emit_token_count = true;
}
ResponseEvent::ModelsEtag(etag) => {
// Update internal state with latest models etag
@@ -2116,8 +2118,9 @@ async fn try_run_sampling_request(
&mut assistant_message_stream_parsers,
)
.await;
sess.update_token_usage_info(&turn_context, token_usage.as_ref())
sess.record_token_usage_info(&turn_context, token_usage.as_ref())
.await;
should_emit_token_count = true;
should_emit_turn_diff = true;
if let Some(false) = end_turn {
needs_follow_up = true;
@@ -2245,6 +2248,14 @@ async fn try_run_sampling_request(
drain_in_flight(&mut in_flight, sess.clone(), turn_context.clone()).await?;
if should_emit_token_count {
// A tool call such as request_user_input can intentionally pause the turn. Emit token
// counts only after pending tools resolve so clients do not see progress events while the
// turn is waiting on the user. This also needs to happen before returning cancellation so
// token usage already recorded from the completed response is still persisted.
sess.send_token_count_event(&turn_context).await;
}
if cancellation_token.is_cancelled() {
return Err(CodexErr::TurnAborted);
}
+53
View File
@@ -58,8 +58,10 @@ use codex_thread_store::LocalThreadStoreConfig;
use codex_thread_store::ReadThreadByRolloutPathParams;
use codex_thread_store::ReadThreadParams;
use codex_thread_store::StoredThread;
use codex_thread_store::ThreadMetadataPatch;
use codex_thread_store::ThreadStore;
use codex_thread_store::ThreadStoreError;
use codex_thread_store::UpdateThreadMetadataParams;
use codex_utils_absolute_path::AbsolutePathBuf;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
@@ -456,6 +458,44 @@ impl ThreadManager {
self.state.get_thread(thread_id).await
}
/// Updates metadata for loaded and cold threads through one entrypoint.
///
/// Loaded threads route through `CodexThread`/`LiveThread`, so metadata changes stay ordered
/// with live rollout writes. Cold threads go directly to the store, which owns unloaded JSONL
/// compatibility and SQLite metadata updates.
///
/// # Errors
///
/// Returns `CodexErr::InvalidRequest` when the thread is loaded and its config snapshot marks it
/// as ephemeral. Store errors from either path are translated by
/// `thread_store_metadata_update_error`.
pub async fn update_thread_metadata(
    &self,
    thread_id: ThreadId,
    patch: ThreadMetadataPatch,
    include_archived: bool,
) -> CodexResult<StoredThread> {
    // Any `get_thread` failure is treated as "not loaded" and falls through to the store path.
    if let Ok(thread) = self.get_thread(thread_id).await {
        if thread.config_snapshot().await.ephemeral {
            return Err(CodexErr::InvalidRequest(format!(
                "ephemeral thread does not support metadata updates: {thread_id}"
            )));
        }
        return thread
            .update_thread_metadata(patch, include_archived)
            .await
            .map_err(|err| thread_store_metadata_update_error(thread_id, err));
    }

    // `thread_store_metadata_update_error` already maps `ThreadNotFound` (and every other
    // variant), so both paths share the same translation without a separate match arm here.
    self.state
        .thread_store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch,
            include_archived,
        })
        .await
        .map_err(|err| thread_store_metadata_update_error(thread_id, err))
}
/// List `thread_id` plus all known descendants in its spawn subtree.
pub async fn list_agent_subtree_thread_ids(
&self,
@@ -1298,6 +1338,19 @@ fn thread_store_rollout_read_error(err: ThreadStoreError) -> CodexErr {
}
}
/// Translates a `ThreadStoreError` raised by a metadata update into a `CodexErr`.
///
/// `thread_id` only labels the fallback `Fatal` message; a `ThreadNotFound` error reports the id
/// carried by the store error itself.
fn thread_store_metadata_update_error(thread_id: ThreadId, err: ThreadStoreError) -> CodexErr {
    match err {
        ThreadStoreError::ThreadNotFound {
            thread_id: missing_id,
        } => CodexErr::ThreadNotFound(missing_id),
        ThreadStoreError::InvalidRequest { message } => CodexErr::InvalidRequest(message),
        ThreadStoreError::Unsupported { operation } => {
            let message =
                format!("thread metadata update is not supported by this store: {operation}");
            CodexErr::UnsupportedOperation(message)
        }
        // Every remaining store failure is surfaced as a fatal error tagged with the thread id.
        other => {
            let message = format!("failed to update thread metadata {thread_id}: {other}");
            CodexErr::Fatal(message)
        }
    }
}
/// Return a fork snapshot cut strictly before the nth user message (0-based).
///
/// Out-of-range values keep the full committed history at a turn boundary, but
-32
View File
@@ -2497,38 +2497,6 @@ async fn token_count_includes_rate_limits_snapshot() {
.await
.unwrap();
let first_token_event =
wait_for_event(&codex, |msg| matches!(msg, EventMsg::TokenCount(_))).await;
let rate_limit_only = match first_token_event {
EventMsg::TokenCount(ev) => ev,
_ => unreachable!(),
};
let rate_limit_json = serde_json::to_value(&rate_limit_only).unwrap();
pretty_assertions::assert_eq!(
rate_limit_json,
json!({
"info": null,
"rate_limits": {
"limit_id": "codex",
"limit_name": null,
"primary": {
"used_percent": 12.5,
"window_minutes": 10,
"resets_at": 1704069000
},
"secondary": {
"used_percent": 40.0,
"window_minutes": 60,
"resets_at": 1704074400
},
"credits": null,
"plan_type": null,
"rate_limit_reached_type": null
}
})
);
let token_event = wait_for_event(
&codex,
|msg| matches!(msg, EventMsg::TokenCount(ev) if ev.info.is_some()),
+1 -1
View File
@@ -891,7 +891,7 @@ async fn handle_response_item_records_tool_result_for_custom_tool_call() {
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TokenCount(_))).await;
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
logs_assert(|lines: &[&str]| {
let line = lines
@@ -17,6 +17,7 @@ use core_test_support::responses;
use core_test_support::responses::ResponsesRequest;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_completed_with_tokens;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::sse;
@@ -30,6 +31,8 @@ use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use tokio::time::Duration;
use tokio::time::timeout;
fn call_output(req: &ResponsesRequest, call_id: &str) -> String {
let raw = req.function_call_output(call_id);
@@ -118,6 +121,7 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
let first_response = sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "request_user_input", &request_args),
ev_rate_limits(),
ev_completed("resp-1"),
]);
responses::mount_sse_once(&server, first_response).await;
@@ -169,6 +173,22 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
assert_eq!(request.call_id, call_id);
assert_eq!(request.questions.len(), 1);
assert_eq!(request.questions[0].is_other, true);
assert!(
timeout(Duration::from_millis(200), async {
loop {
let event = match codex.next_event().await {
Ok(event) => event,
Err(err) => panic!("event stream should stay open: {err}"),
};
if matches!(event.msg, EventMsg::TokenCount(_)) {
return;
}
}
})
.await
.is_err(),
"TokenCount should wait until request_user_input resolves"
);
let mut answers = HashMap::new();
answers.insert(
@@ -185,6 +205,7 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
})
.await?;
wait_for_event(&codex, |event| matches!(event, EventMsg::TokenCount(_))).await;
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
let req = second_mock.single_request();
@@ -202,6 +223,118 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
Ok(())
}
/// Builds a `codex.rate_limits` event payload with one primary window (42% used, 60-minute
/// window) and no secondary window, credits, or promo data.
fn ev_rate_limits() -> Value {
    let primary_window = json!({
        "used_percent": 42,
        "window_minutes": 60,
        "reset_at": 1700000000
    });
    let limits = json!({
        "allowed": true,
        "limit_reached": false,
        "primary": primary_window,
        "secondary": null
    });
    json!({
        "type": "codex.rate_limits",
        "plan_type": "plus",
        "rate_limits": limits,
        "code_review_rate_limits": null,
        "credits": null,
        "promo": null
    })
}
/// Interrupts a turn while a `request_user_input` tool call is pending and checks that a
/// `TokenCount` event reporting the completed response's usage (77 total tokens) is observed
/// before the test waits for `TurnAborted`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn request_user_input_interrupt_emits_deferred_token_count() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let TestCodex {
codex,
cwd,
session_configured,
..
} = test_codex().build(&server).await?;
let call_id = "user-input-interrupt";
// Arguments for the `request_user_input` call: a single two-option confirmation question.
let request_args = json!({
"questions": [{
"id": "confirm_path",
"header": "Confirm",
"question": "Proceed with the plan?",
"options": [{
"label": "Yes (Recommended)",
"description": "Continue the current plan."
}, {
"label": "No",
"description": "Stop and revisit the approach."
}]
}]
})
.to_string();
// The mocked response issues the tool call and then completes with 77 total tokens, so usage
// is already known while the tool call is still waiting on the user.
let response = sse(vec![
ev_response_created("resp-interrupt"),
ev_function_call(call_id, "request_user_input", &request_args),
ev_completed_with_tokens("resp-interrupt", /*total_tokens*/ 77),
]);
responses::mount_sse_once(&server, response).await;
let (sandbox_policy, permission_profile) =
turn_permission_fields(PermissionProfile::Disabled, cwd.path());
// Submit the turn in Plan collaboration mode.
codex
.submit(Op::UserTurn {
environments: None,
items: vec![UserInput::Text {
text: "please confirm".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
cwd: cwd.path().to_path_buf(),
approval_policy: AskForApproval::Never,
approvals_reviewer: None,
sandbox_policy,
permission_profile,
model: session_configured.model.clone(),
effort: None,
summary: None,
service_tier: None,
collaboration_mode: Some(CollaborationMode {
mode: ModeKind::Plan,
settings: Settings {
model: session_configured.model,
reasoning_effort: None,
developer_instructions: None,
},
}),
personality: None,
})
.await?;
// Wait until the tool call surfaces as a user-input request, then interrupt instead of
// answering it.
let request = wait_for_event_match(&codex, |event| match event {
EventMsg::RequestUserInput(request) => Some(request.clone()),
_ => None,
})
.await;
codex.submit(Op::Interrupt).await?;
// The first TokenCount seen after the interrupt must carry the usage from the completed
// response.
let token_count = wait_for_event_match(&codex, |event| match event {
EventMsg::TokenCount(token_count) => Some(token_count.clone()),
_ => None,
})
.await;
assert_eq!(
token_count
.info
.map(|info| info.total_token_usage.total_tokens),
Some(77)
);
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnAborted(_))).await;
assert_eq!(request.call_id, call_id);
Ok(())
}
async fn assert_request_user_input_rejected<F>(mode_name: &str, build_mode: F) -> anyhow::Result<()>
where
F: FnOnce(String) -> CollaborationMode,
-4
View File
@@ -109,7 +109,6 @@ async fn resume_includes_initial_messages_from_rollout_events() -> Result<()> {
[
EventMsg::TurnStarted(_),
EventMsg::UserMessage(_),
EventMsg::TokenCount(_),
EventMsg::AgentMessage(_),
EventMsg::TokenCount(_),
EventMsg::TurnComplete(_),
@@ -126,7 +125,6 @@ async fn resume_includes_initial_messages_from_rollout_events() -> Result<()> {
[
EventMsg::TurnStarted(started),
EventMsg::UserMessage(first_user),
EventMsg::TokenCount(_),
EventMsg::AgentMessage(assistant_message),
EventMsg::TokenCount(_),
EventMsg::TurnComplete(completed),
@@ -196,7 +194,6 @@ async fn resume_includes_initial_messages_from_reasoning_events() -> Result<()>
[
EventMsg::TurnStarted(_),
EventMsg::UserMessage(_),
EventMsg::TokenCount(_),
EventMsg::AgentReasoning(_),
EventMsg::AgentReasoningRawContent(_),
EventMsg::AgentMessage(_),
@@ -215,7 +212,6 @@ async fn resume_includes_initial_messages_from_reasoning_events() -> Result<()>
[
EventMsg::TurnStarted(started),
EventMsg::UserMessage(first_user),
EventMsg::TokenCount(_),
EventMsg::AgentReasoning(reasoning),
EventMsg::AgentReasoningRawContent(raw),
EventMsg::AgentMessage(assistant_message),
@@ -4,7 +4,6 @@ use std::path::Path;
use std::path::PathBuf;
use chrono::Utc;
use codex_core::EventPersistenceMode;
use codex_core::RolloutRecorder;
use codex_core::RolloutRecorderParams;
use codex_core::config::ConfigBuilder;
@@ -189,10 +188,7 @@ async fn find_locates_rollout_file_written_by_recorder() -> std::io::Result<()>
/*thread_source*/ None,
BaseInstructions::default(),
Vec::new(),
EventPersistenceMode::Limited,
),
/*state_db_ctx*/ None,
/*state_builder*/ None,
)
.await?;
recorder.persist().await?;
+1
View File
@@ -55,6 +55,7 @@ pub use list::rollout_date_parts;
pub use metadata::builder_from_items;
pub use policy::EventPersistenceMode;
pub use policy::is_persisted_rollout_item;
pub use policy::persisted_rollout_items;
pub use policy::should_persist_response_item_for_memories;
pub use recorder::RolloutRecorder;
pub use recorder::RolloutRecorderParams;
+40
View File
@@ -1,6 +1,9 @@
use crate::protocol::EventMsg;
use crate::protocol::RolloutItem;
use codex_protocol::models::ResponseItem;
use codex_utils_string::truncate_middle_chars;
const PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES: usize = 10_000;
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EventPersistenceMode {
@@ -22,6 +25,43 @@ pub fn is_persisted_rollout_item(item: &RolloutItem, mode: EventPersistenceMode)
}
}
/// Collects the canonical rollout items to persist for a live append.
///
/// Items rejected by `is_persisted_rollout_item` for `mode` are dropped; the rest are cloned and
/// passed through `sanitize_rollout_item_for_persistence`, keeping their input order.
pub fn persisted_rollout_items(
    items: &[RolloutItem],
    mode: EventPersistenceMode,
) -> Vec<RolloutItem> {
    items
        .iter()
        .filter(|candidate| is_persisted_rollout_item(candidate, mode))
        .map(|candidate| sanitize_rollout_item_for_persistence(candidate.clone(), mode))
        .collect()
}
/// Prepares a rollout item for storage under the given persistence mode.
///
/// Only `ExecCommandEnd` events in `EventPersistenceMode::Extended` are altered; every other
/// item, and every item in any other mode, is returned unchanged.
fn sanitize_rollout_item_for_persistence(
    item: RolloutItem,
    mode: EventPersistenceMode,
) -> RolloutItem {
    match (mode, item) {
        (
            EventPersistenceMode::Extended,
            RolloutItem::EventMsg(EventMsg::ExecCommandEnd(mut exec_end)),
        ) => {
            // Keep only a bounded aggregated summary of the command output.
            // NOTE(review): the helper is named `..._chars` while the limit constant is named
            // `..._MAX_BYTES` — confirm which unit the limit is actually applied in.
            let bounded_output = truncate_middle_chars(
                &exec_end.aggregated_output,
                PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES,
            );
            exec_end.aggregated_output = bounded_output;
            // The remaining output fields are emptied before the event is stored.
            exec_end.stdout.clear();
            exec_end.stderr.clear();
            exec_end.formatted_output.clear();
            RolloutItem::EventMsg(EventMsg::ExecCommandEnd(exec_end))
        }
        (_, untouched) => untouched,
    }
}
/// Whether a `ResponseItem` should be persisted in rollout files.
#[inline]
pub fn should_persist_response_item(item: &ResponseItem) -> bool {
+74 -318
View File
@@ -8,16 +8,11 @@ use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use std::time::Instant;
use chrono::DateTime;
use chrono::SecondsFormat;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::models::BaseInstructions;
use codex_utils_string::truncate_middle_chars;
use serde_json::Value;
use time::OffsetDateTime;
use time::format_description::FormatItem;
@@ -47,15 +42,13 @@ use super::list::get_threads_in_root;
use super::list::parse_cursor;
use super::list::parse_timestamp_uuid_from_filename;
use super::metadata;
use super::policy::EventPersistenceMode;
use super::policy::is_persisted_rollout_item;
use super::session_index::find_thread_names_by_ids;
use crate::config::RolloutConfigView;
use crate::default_client::originator;
use crate::state_db;
use crate::state_db::StateDbHandle;
use codex_git_utils::collect_git_info;
use codex_protocol::protocol::EventMsg;
use codex_git_utils::get_git_repo_root;
use codex_protocol::protocol::GitInfo as ProtocolGitInfo;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::ResumedHistory;
@@ -66,11 +59,9 @@ use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::ThreadSource;
use codex_state::StateRuntime;
use codex_state::ThreadMetadataBuilder;
use codex_utils_path as path_utils;
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
/// Writes canonical session rollout items to JSONL.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
///
@@ -83,7 +74,6 @@ pub struct RolloutRecorder {
tx: Sender<RolloutCmd>,
writer_task: Arc<RolloutWriterTask>,
pub(crate) rollout_path: PathBuf,
event_persistence_mode: EventPersistenceMode,
}
#[derive(Clone)]
@@ -95,11 +85,9 @@ pub enum RolloutRecorderParams {
thread_source: Option<ThreadSource>,
base_instructions: BaseInstructions,
dynamic_tools: Vec<DynamicToolSpec>,
event_persistence_mode: EventPersistenceMode,
},
Resume {
path: PathBuf,
event_persistence_mode: EventPersistenceMode,
},
}
@@ -172,7 +160,6 @@ impl RolloutRecorderParams {
thread_source: Option<ThreadSource>,
base_instructions: BaseInstructions,
dynamic_tools: Vec<DynamicToolSpec>,
event_persistence_mode: EventPersistenceMode,
) -> Self {
Self::Create {
conversation_id,
@@ -181,42 +168,11 @@ impl RolloutRecorderParams {
thread_source,
base_instructions,
dynamic_tools,
event_persistence_mode,
}
}
pub fn resume(path: PathBuf, event_persistence_mode: EventPersistenceMode) -> Self {
Self::Resume {
path,
event_persistence_mode,
}
}
}
const PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES: usize = 10_000;
fn sanitize_rollout_item_for_persistence(
item: RolloutItem,
mode: EventPersistenceMode,
) -> RolloutItem {
if mode != EventPersistenceMode::Extended {
return item;
}
match item {
RolloutItem::EventMsg(EventMsg::ExecCommandEnd(mut event)) => {
// Persist only a bounded aggregated summary of command output.
event.aggregated_output = truncate_middle_chars(
&event.aggregated_output,
PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES,
);
// Drop unnecessary fields from rollout storage since aggregated_output is all we need.
event.stdout.clear();
event.stderr.clear();
event.formatted_output.clear();
RolloutItem::EventMsg(EventMsg::ExecCommandEnd(event))
}
_ => item,
pub fn resume(path: PathBuf) -> Self {
Self::Resume { path }
}
}
@@ -691,80 +647,65 @@ impl RolloutRecorder {
pub async fn new(
config: &impl RolloutConfigView,
params: RolloutRecorderParams,
state_db_ctx: Option<StateDbHandle>,
state_builder: Option<ThreadMetadataBuilder>,
) -> std::io::Result<Self> {
let (file, deferred_log_file_info, rollout_path, meta, event_persistence_mode) =
match params {
RolloutRecorderParams::Create {
conversation_id,
let (file, deferred_log_file_info, rollout_path, meta) = match params {
RolloutRecorderParams::Create {
conversation_id,
forked_from_id,
source,
thread_source,
base_instructions,
dynamic_tools,
} => {
let log_file_info = precompute_log_file_info(config, conversation_id)?;
let path = log_file_info.path.clone();
let session_id = log_file_info.conversation_id;
let started_at = log_file_info.timestamp;
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
);
let timestamp = started_at
.to_offset(time::UtcOffset::UTC)
.format(timestamp_format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let session_meta = SessionMeta {
id: session_id,
forked_from_id,
timestamp,
cwd: config.cwd().to_path_buf(),
originator: originator().value,
cli_version: env!("CARGO_PKG_VERSION").to_string(),
agent_nickname: source.get_nickname(),
agent_role: source.get_agent_role(),
agent_path: source.get_agent_path().map(Into::into),
source,
thread_source,
base_instructions,
dynamic_tools,
event_persistence_mode,
} => {
let log_file_info = precompute_log_file_info(config, conversation_id)?;
let path = log_file_info.path.clone();
let session_id = log_file_info.conversation_id;
let started_at = log_file_info.timestamp;
model_provider: Some(config.model_provider_id().to_string()),
base_instructions: Some(base_instructions),
dynamic_tools: if dynamic_tools.is_empty() {
None
} else {
Some(dynamic_tools)
},
memory_mode: (!config.generate_memories()).then_some("disabled".to_string()),
};
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
);
let timestamp = started_at
.to_offset(time::UtcOffset::UTC)
.format(timestamp_format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let session_meta = SessionMeta {
id: session_id,
forked_from_id,
timestamp,
cwd: config.cwd().to_path_buf(),
originator: originator().value,
cli_version: env!("CARGO_PKG_VERSION").to_string(),
agent_nickname: source.get_nickname(),
agent_role: source.get_agent_role(),
agent_path: source.get_agent_path().map(Into::into),
source,
thread_source,
model_provider: Some(config.model_provider_id().to_string()),
base_instructions: Some(base_instructions),
dynamic_tools: if dynamic_tools.is_empty() {
None
} else {
Some(dynamic_tools)
},
memory_mode: (!config.generate_memories())
.then_some("disabled".to_string()),
};
(
None,
Some(log_file_info),
path,
Some(session_meta),
event_persistence_mode,
)
}
RolloutRecorderParams::Resume {
path,
event_persistence_mode,
} => (
Some(
tokio::fs::OpenOptions::new()
.append(true)
.open(&path)
.await?,
),
None,
path,
None,
event_persistence_mode,
(None, Some(log_file_info), path, Some(session_meta))
}
RolloutRecorderParams::Resume { path } => (
Some(
tokio::fs::OpenOptions::new()
.append(true)
.open(&path)
.await?,
),
};
None,
path,
None,
),
};
// Clone the cwd for the spawned task to collect git info asynchronously
let cwd = config.cwd().to_path_buf();
@@ -779,9 +720,6 @@ impl RolloutRecorder {
let writer_task = Arc::new(RolloutWriterTask::new());
let writer_task_for_spawn = Arc::clone(&writer_task);
let rollout_path_for_spawn = rollout_path.clone();
let default_provider = config.model_provider_id().to_string();
let generate_memories = config.generate_memories();
let state_db_ctx_for_spawn = state_db_ctx.clone();
let handle = tokio::task::spawn(async move {
let result = rollout_writer(
file,
@@ -790,10 +728,6 @@ impl RolloutRecorder {
meta,
cwd,
rollout_path_for_spawn.clone(),
state_db_ctx_for_spawn,
state_builder,
default_provider,
generate_memories,
)
.await;
if let Err(err) = result {
@@ -812,7 +746,6 @@ impl RolloutRecorder {
tx,
writer_task,
rollout_path,
event_persistence_mode,
})
}
@@ -820,24 +753,12 @@ impl RolloutRecorder {
self.rollout_path.as_path()
}
pub async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
let mut filtered = Vec::new();
for item in items {
// Note that function calls may look a bit strange if they are
// "fully qualified MCP tool calls," so we could consider
// reformatting them in that case.
if is_persisted_rollout_item(item, self.event_persistence_mode) {
filtered.push(sanitize_rollout_item_for_persistence(
item.clone(),
self.event_persistence_mode,
));
}
}
if filtered.is_empty() {
pub async fn record_canonical_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
if items.is_empty() {
return Ok(());
}
self.tx
.send(RolloutCmd::AddItems(filtered))
.send(RolloutCmd::AddItems(items.to_vec()))
.await
.map_err(|e| {
self.writer_task.terminal_failure().unwrap_or_else(|| {
@@ -1459,48 +1380,17 @@ struct RolloutWriterState {
meta: Option<SessionMeta>,
cwd: PathBuf,
rollout_path: PathBuf,
state_db_ctx: Option<StateDbHandle>,
state_builder: Option<ThreadMetadataBuilder>,
default_provider: String,
generate_memories: bool,
thread_updated_at_touch: ThreadUpdatedAtTouch,
last_logged_error: Option<String>,
}
#[cfg(not(test))]
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_secs(5);
#[cfg(test)]
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_millis(50);
#[derive(Default)]
struct ThreadUpdatedAtTouch {
last_persisted_at: Option<Instant>,
pending_touch: Option<(ThreadId, DateTime<Utc>)>,
}
impl ThreadUpdatedAtTouch {
fn mark_persisted(&mut self, now: Instant) {
self.last_persisted_at = Some(now);
self.pending_touch = None;
}
}
impl RolloutWriterState {
#[allow(clippy::too_many_arguments)]
fn new(
file: Option<tokio::fs::File>,
deferred_log_file_info: Option<LogFileInfo>,
meta: Option<SessionMeta>,
cwd: PathBuf,
rollout_path: PathBuf,
state_db_ctx: Option<StateDbHandle>,
mut state_builder: Option<ThreadMetadataBuilder>,
default_provider: String,
generate_memories: bool,
) -> Self {
if let Some(builder) = state_builder.as_mut() {
builder.rollout_path = rollout_path.clone();
}
Self {
writer: file.map(|file| JsonlWriter { file }),
deferred_log_file_info,
@@ -1508,11 +1398,6 @@ impl RolloutWriterState {
meta,
cwd,
rollout_path,
state_db_ctx,
state_builder,
default_provider,
generate_memories,
thread_updated_at_touch: ThreadUpdatedAtTouch::default(),
last_logged_error: None,
}
}
@@ -1545,19 +1430,7 @@ impl RolloutWriterState {
if self.is_deferred() && self.pending_items.is_empty() {
return Ok(());
}
self.write_pending_with_recovery("shutdown").await?;
if let Some((thread_id, updated_at)) = self.thread_updated_at_touch.pending_touch.take()
&& state_db::touch_thread_updated_at(
self.state_db_ctx.as_deref(),
Some(thread_id),
updated_at,
"rollout_writer_shutdown",
)
.await
{
self.thread_updated_at_touch.mark_persisted(Instant::now());
}
Ok(())
self.write_pending_with_recovery("shutdown").await
}
async fn write_pending_with_recovery(&mut self, operation: &str) -> std::io::Result<()> {
@@ -1625,18 +1498,7 @@ impl RolloutWriterState {
let Some(session_meta) = self.meta.as_ref().cloned() else {
return Ok(());
};
write_session_meta(
self.writer.as_mut(),
session_meta,
&self.cwd,
&self.rollout_path,
self.state_db_ctx.as_deref(),
&mut self.state_builder,
self.default_provider.as_str(),
self.generate_memories,
&mut self.thread_updated_at_touch,
)
.await?;
write_session_meta(self.writer.as_mut(), session_meta, &self.cwd).await?;
self.meta = None;
Ok(())
}
@@ -1669,25 +1531,13 @@ impl RolloutWriterState {
}
if written_count > 0 {
let written_items: Vec<RolloutItem> =
self.pending_items.drain(..written_count).collect();
sync_thread_state_after_write(
self.state_db_ctx.as_deref(),
&self.rollout_path,
self.state_builder.as_ref(),
written_items.as_slice(),
self.default_provider.as_str(),
/*new_thread_memory_mode*/ None,
&mut self.thread_updated_at_touch,
)
.await;
self.pending_items.drain(..written_count);
}
write_result
}
}
#[allow(clippy::too_many_arguments)]
async fn rollout_writer(
file: Option<tokio::fs::File>,
deferred_log_file_info: Option<LogFileInfo>,
@@ -1695,22 +1545,8 @@ async fn rollout_writer(
meta: Option<SessionMeta>,
cwd: PathBuf,
rollout_path: PathBuf,
state_db_ctx: Option<StateDbHandle>,
state_builder: Option<ThreadMetadataBuilder>,
default_provider: String,
generate_memories: bool,
) -> std::io::Result<()> {
let mut state = RolloutWriterState::new(
file,
deferred_log_file_info,
meta,
cwd,
rollout_path,
state_db_ctx,
state_builder,
default_provider,
generate_memories,
);
let mut state = RolloutWriterState::new(file, deferred_log_file_info, meta, cwd, rollout_path);
// Process rollout commands
while let Some(cmd) = rx.recv().await {
@@ -1740,116 +1576,36 @@ async fn rollout_writer(
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn write_session_meta(
mut writer: Option<&mut JsonlWriter>,
session_meta: SessionMeta,
cwd: &Path,
rollout_path: &Path,
state_db_ctx: Option<&StateRuntime>,
state_builder: &mut Option<ThreadMetadataBuilder>,
default_provider: &str,
generate_memories: bool,
thread_updated_at_touch: &mut ThreadUpdatedAtTouch,
) -> std::io::Result<()> {
let git_info = collect_git_info(cwd).await.map(|info| ProtocolGitInfo {
commit_hash: info.commit_hash,
branch: info.branch,
repository_url: info.repository_url,
});
let git_info = if get_git_repo_root(cwd).is_some() {
collect_git_info(cwd).await.map(|info| ProtocolGitInfo {
commit_hash: info.commit_hash,
branch: info.branch,
repository_url: info.repository_url,
})
} else {
None
};
let session_meta_line = SessionMetaLine {
meta: session_meta,
git: git_info,
};
if state_db_ctx.is_some() {
*state_builder = metadata::builder_from_session_meta(&session_meta_line, rollout_path);
}
let rollout_item = RolloutItem::SessionMeta(session_meta_line);
if let Some(writer) = writer.as_mut() {
writer.write_rollout_item(&rollout_item).await?;
}
sync_thread_state_after_write(
state_db_ctx,
rollout_path,
state_builder.as_ref(),
std::slice::from_ref(&rollout_item),
default_provider,
(!generate_memories).then_some("disabled"),
thread_updated_at_touch,
)
.await;
Ok(())
}
/// Brings the state DB in line with rollout items that were just written.
///
/// Items that affect thread metadata (or an explicit `new_thread_memory_mode`)
/// are always applied in full. Metadata-irrelevant items only need an
/// `updated_at` touch, which is rate-limited by
/// `THREAD_UPDATED_AT_TOUCH_INTERVAL`: a touch that arrives too soon is parked
/// in `thread_updated_at_touch.pending_touch` instead of being written. When a
/// touch is attempted and reports failure, the items are applied in full.
async fn sync_thread_state_after_write(
    state_db_ctx: Option<&StateRuntime>,
    rollout_path: &Path,
    state_builder: Option<&ThreadMetadataBuilder>,
    items: &[RolloutItem],
    default_provider: &str,
    new_thread_memory_mode: Option<&str>,
    thread_updated_at_touch: &mut ThreadUpdatedAtTouch,
) {
    // Both clocks are sampled before any awaited work so the recorded times
    // reflect when the write finished rather than when the DB call returned.
    let updated_at = Utc::now();
    let now = Instant::now();

    let needs_full_apply = new_thread_memory_mode.is_some()
        || items
            .iter()
            .any(codex_state::rollout_item_affects_thread_metadata);

    if !needs_full_apply {
        // Prefer the id of the known builder; otherwise try to derive one from the items.
        let thread_id = match state_builder {
            Some(builder) => Some(builder.id),
            None => metadata::builder_from_items(items, rollout_path).map(|builder| builder.id),
        };

        let touched_recently = thread_updated_at_touch
            .last_persisted_at
            .is_some_and(|last| now.duration_since(last) < THREAD_UPDATED_AT_TOUCH_INTERVAL);
        if touched_recently {
            // Too soon for another write: remember the newest timestamp for later.
            thread_updated_at_touch.pending_touch = thread_id.map(|id| (id, updated_at));
            return;
        }

        let touched =
            state_db::touch_thread_updated_at(state_db_ctx, thread_id, updated_at, "rollout_writer")
                .await;
        if touched {
            thread_updated_at_touch.mark_persisted(now);
            return;
        }
        // The touch did not succeed; fall through to the full apply below.
    }

    state_db::apply_rollout_items(
        state_db_ctx,
        rollout_path,
        default_provider,
        state_builder,
        items,
        "rollout_writer",
        new_thread_memory_mode,
        Some(updated_at),
    )
    .await;
    thread_updated_at_touch.mark_persisted(now);
}
/// Append one already-filtered rollout item to an existing rollout JSONL file.
///
/// This is for metadata updates to unloaded threads. Live sessions should use
/// `RolloutRecorder::record_items` so rollout and SQLite updates remain ordered
/// `RolloutRecorder::record_canonical_items` so rollout writes remain ordered
/// with the rest of the session stream.
pub async fn append_rollout_item_to_path(
rollout_path: &Path,
+3 -224
View File
@@ -372,10 +372,7 @@ async fn recorder_materializes_on_flush_with_pending_items() -> std::io::Result<
/*thread_source*/ None,
BaseInstructions::default(),
Vec::new(),
EventPersistenceMode::Limited,
),
/*state_db_ctx*/ None,
/*state_builder*/ None,
)
.await?;
@@ -386,7 +383,7 @@ async fn recorder_materializes_on_flush_with_pending_items() -> std::io::Result<
);
recorder
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
.record_canonical_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
AgentMessageEvent {
message: "buffered-event".to_string(),
phase: None,
@@ -401,7 +398,7 @@ async fn recorder_materializes_on_flush_with_pending_items() -> std::io::Result<
);
recorder
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
.record_canonical_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
UserMessageEvent {
message: "first-user-message".to_string(),
images: None,
@@ -453,16 +450,13 @@ async fn persist_reports_filesystem_error_and_retries_buffered_items() -> std::i
/*thread_source*/ None,
BaseInstructions::default(),
Vec::new(),
EventPersistenceMode::Limited,
),
/*state_db_ctx*/ None,
/*state_builder*/ None,
)
.await?;
let rollout_path = recorder.rollout_path().to_path_buf();
recorder
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
.record_canonical_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
AgentMessageEvent {
message: "buffered-before-persist".to_string(),
phase: None,
@@ -498,7 +492,6 @@ async fn persist_reports_filesystem_error_and_retries_buffered_items() -> std::i
#[tokio::test]
async fn writer_state_retries_write_error_before_reporting_flush_success() -> std::io::Result<()> {
let home = TempDir::new().expect("temp dir");
let config = test_config(home.path());
let rollout_path = home.path().join("rollout.jsonl");
File::create(&rollout_path)?;
let read_only_file = std::fs::OpenOptions::new().read(true).open(&rollout_path)?;
@@ -508,10 +501,6 @@ async fn writer_state_retries_write_error_before_reporting_flush_success() -> st
/*meta*/ None,
home.path().to_path_buf(),
rollout_path.clone(),
/*state_db_ctx*/ None,
/*state_builder*/ None,
config.model_provider_id.clone(),
config.generate_memories,
);
state.add_items(vec![RolloutItem::EventMsg(EventMsg::AgentMessage(
AgentMessageEvent {
@@ -530,216 +519,6 @@ async fn writer_state_retries_write_error_before_reporting_flush_success() -> st
Ok(())
}
// Checks that agent messages (which do not change title or first user message) only
// advance `updated_at` in the state DB once the touch interval has elapsed.
#[tokio::test]
async fn metadata_irrelevant_events_coalesce_state_db_updated_at() -> std::io::Result<()> {
let home = TempDir::new().expect("temp dir");
let config = test_config(home.path());
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
.await
.expect("state db should initialize");
state_db
.mark_backfill_complete(/*last_watermark*/ None)
.await
.expect("backfill should be complete");
let thread_id = ThreadId::new();
let recorder = RolloutRecorder::new(
&config,
RolloutRecorderParams::new(
thread_id,
/*forked_from_id*/ None,
SessionSource::Cli,
/*thread_source*/ None,
BaseInstructions::default(),
Vec::new(),
EventPersistenceMode::Limited,
),
Some(state_db.clone()),
/*state_builder*/ None,
)
.await?;
// Step 1: a user message followed by persist + flush; the thread row is expected to exist afterwards.
recorder
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
UserMessageEvent {
message: "first-user-message".to_string(),
images: None,
local_images: Vec::new(),
text_elements: Vec::new(),
},
))])
.await?;
recorder.persist().await?;
recorder.flush().await?;
// Capture the baseline values the later assertions compare against.
let initial_thread = state_db
.get_thread(thread_id)
.await
.expect("thread should load")
.expect("thread should exist");
let initial_updated_at = initial_thread.updated_at;
let initial_title = initial_thread.title.clone();
let initial_first_user_message = initial_thread.first_user_message.clone();
// Step 2: an agent message recorded right away must leave `updated_at` untouched.
recorder
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
AgentMessageEvent {
message: "assistant text".to_string(),
phase: None,
memory_citation: None,
},
))])
.await?;
recorder.flush().await?;
let updated_thread = state_db
.get_thread(thread_id)
.await
.expect("thread should load after agent message")
.expect("thread should still exist");
assert_eq!(updated_thread.updated_at, initial_updated_at);
assert_eq!(updated_thread.title, initial_title);
assert_eq!(
updated_thread.first_user_message,
initial_first_user_message
);
// Step 3: wait slightly longer than the touch interval so the next write is allowed to touch.
tokio::time::sleep(THREAD_UPDATED_AT_TOUCH_INTERVAL + Duration::from_millis(10)).await;
recorder
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
AgentMessageEvent {
message: "more assistant text".to_string(),
phase: None,
memory_citation: None,
},
))])
.await?;
recorder.flush().await?;
let refreshed_thread = state_db
.get_thread(thread_id)
.await
.expect("thread should load after refresh")
.expect("thread should still exist");
// `updated_at` has advanced, while the metadata derived from the user message is unchanged.
assert!(refreshed_thread.updated_at > initial_updated_at);
assert_eq!(refreshed_thread.title, initial_title);
assert_eq!(
refreshed_thread.first_user_message,
initial_first_user_message
);
recorder.shutdown().await?;
Ok(())
}
// Checks that `RolloutWriterState::shutdown` writes a deferred `updated_at` touch to the state DB.
#[tokio::test]
async fn shutdown_flushes_pending_metadata_irrelevant_updated_at() -> std::io::Result<()> {
let home = TempDir::new().expect("temp dir");
let config = test_config(home.path());
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
.await
.expect("state db should initialize");
state_db
.mark_backfill_complete(/*last_watermark*/ None)
.await
.expect("backfill should be complete");
let thread_id = ThreadId::new();
let rollout_path = home.path().join("rollout.jsonl");
// Seed the DB with a thread row carrying a fixed, known `updated_at`.
let initial_updated_at = Utc.with_ymd_and_hms(2026, 5, 7, 7, 37, 8).unwrap();
let builder = ThreadMetadataBuilder::new(
thread_id,
rollout_path.clone(),
initial_updated_at,
SessionSource::Cli,
);
state_db
.upsert_thread(&builder.build(config.model_provider_id.as_str()))
.await
.expect("thread should be inserted");
// Build the writer state directly on an append-mode rollout file, bypassing the recorder.
File::create(&rollout_path)?;
let rollout_file = std::fs::OpenOptions::new()
.append(true)
.open(&rollout_path)?;
let mut state = RolloutWriterState::new(
Some(tokio::fs::File::from_std(rollout_file)),
/*deferred_log_file_info*/ None,
/*meta*/ None,
home.path().to_path_buf(),
rollout_path,
Some(state_db.clone()),
Some(builder),
config.model_provider_id.clone(),
config.generate_memories,
);
// Inject a pending touch by hand, one second after the seeded timestamp.
let pending_updated_at = initial_updated_at + chrono::Duration::seconds(1);
state.thread_updated_at_touch.pending_touch = Some((thread_id, pending_updated_at));
state.shutdown().await?;
// After shutdown the stored `updated_at` must equal the injected pending timestamp.
assert_eq!(
state_db
.get_thread(thread_id)
.await
.expect("thread should load after shutdown")
.expect("thread should still exist")
.updated_at,
pending_updated_at
);
Ok(())
}
// Checks that syncing a metadata-irrelevant item for a thread with no row in the state DB
// still results in the thread row being present afterwards.
#[tokio::test]
async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing() -> std::io::Result<()>
{
let home = TempDir::new().expect("temp dir");
let config = test_config(home.path());
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
.await
.expect("state db should initialize");
let thread_id = ThreadId::new();
let rollout_path = home.path().join("rollout.jsonl");
// The builder identifies the thread, but no row is inserted for it before the sync call.
let builder = ThreadMetadataBuilder::new(
thread_id,
rollout_path.clone(),
Utc::now(),
SessionSource::Cli,
);
let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage(
AgentMessageEvent {
message: "assistant text".to_string(),
phase: None,
memory_citation: None,
},
))];
// A default touch state has no `last_persisted_at`, so the sync is not throttled.
let mut thread_updated_at_touch = ThreadUpdatedAtTouch::default();
sync_thread_state_after_write(
Some(state_db.as_ref()),
rollout_path.as_path(),
Some(&builder),
items.as_slice(),
config.model_provider_id.as_str(),
/*new_thread_memory_mode*/ None,
&mut thread_updated_at_touch,
)
.await;
let thread = state_db
.get_thread(thread_id)
.await
.expect("thread should load after fallback")
.expect("thread should be inserted after fallback");
assert_eq!(thread.id, thread_id);
Ok(())
}
#[tokio::test]
async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> {
let home = TempDir::new().expect("temp dir");
+1
View File
@@ -19,6 +19,7 @@ codex-git-utils = { workspace = true }
codex-protocol = { workspace = true }
codex-rollout = { workspace = true }
codex-state = { workspace = true }
codex-utils-path = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
+35
View File
@@ -0,0 +1,35 @@
# Thread Store
`codex-thread-store` is the storage boundary for Codex threads. It defines the
`ThreadStore` trait plus local and in-memory implementations. Other storage
implementations may live outside this repository.
## Responsibilities
- `ThreadStore::append_items` is the raw canonical history append API. It does
not infer metadata from item contents.
- `ThreadStore::update_thread_metadata` is the only thread metadata write API.
It accepts one metadata patch shape and applies it as given, regardless of
whether the caller is applying a user/API mutation or recording facts derived
above the store from appended history.
- `LiveThread` is the preferred API for active session persistence. It owns a
per-thread metadata sync helper, applies the rollout persistence policy,
appends canonical history, and then sends metadata patches through
`ThreadStore::update_thread_metadata`.
- `ThreadManager` routes metadata mutations for loaded and cold threads through
one entrypoint. Loaded threads use their `LiveThread`; cold threads go
directly to the store.
- `LocalThreadStore` persists history through `codex-rollout` JSONL files and
persists queryable metadata through the SQLite state database when available.
Explicit metadata mutations on local threads also keep the JSONL files and the
name index up to date, so that reads of old or SQLite-less local storage keep working.
- `RolloutRecorder` is the local JSONL writer. It writes already-canonical
items for `ThreadStore::append_items`; it no longer decides metadata updates
for live thread-store appends.
- `core/session` creates or resumes `LiveThread` handles and does not need to
know whether persistence is backed by local files or another store.
## Direction
New metadata observation semantics should live above `ThreadStore`. Stores
persist explicit metadata fields, but raw history appends remain history-only.
+66 -21
View File
@@ -22,6 +22,7 @@ use crate::ReadThreadParams;
use crate::ResumeThreadParams;
use crate::StoredThread;
use crate::StoredThreadHistory;
use crate::ThreadMetadataPatch;
use crate::ThreadPage;
use crate::ThreadStore;
use crate::ThreadStoreError;
@@ -127,6 +128,7 @@ struct InMemoryThreadStoreState {
calls: InMemoryThreadStoreCalls,
created_threads: HashMap<ThreadId, CreateThreadParams>,
histories: HashMap<ThreadId, Vec<RolloutItem>>,
metadata_updates: HashMap<ThreadId, ThreadMetadataPatch>,
names: HashMap<ThreadId, Option<String>>,
rollout_paths: HashMap<PathBuf, ThreadId>,
}
@@ -271,9 +273,14 @@ impl ThreadStore for InMemoryThreadStore {
) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.update_thread_metadata += 1;
if let Some(name) = params.patch.name {
state.names.insert(params.thread_id, Some(name));
if let Some(name) = params.patch.name.clone() {
state.names.insert(params.thread_id, name);
}
state
.metadata_updates
.entry(params.thread_id)
.or_default()
.merge(params.patch);
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
}
@@ -307,6 +314,7 @@ fn stored_thread_from_state(
items: history_items.clone(),
});
let name = state.names.get(&thread_id).cloned().flatten();
let metadata = state.metadata_updates.get(&thread_id);
let rollout_path = state
.rollout_paths
.iter()
@@ -316,28 +324,65 @@ fn stored_thread_from_state(
Ok(StoredThread {
thread_id,
rollout_path,
rollout_path: metadata
.and_then(|metadata| metadata.rollout_path.clone())
.or(rollout_path),
forked_from_id: created.forked_from_id,
preview: String::new(),
preview: metadata
.and_then(|metadata| metadata.preview.clone())
.unwrap_or_default(),
name,
model_provider: "test".to_string(),
model: None,
reasoning_effort: None,
created_at: Utc::now(),
updated_at: Utc::now(),
model_provider: metadata
.and_then(|metadata| metadata.model_provider.clone())
.unwrap_or_else(|| "test".to_string()),
model: metadata.and_then(|metadata| metadata.model.clone()),
reasoning_effort: metadata.and_then(|metadata| metadata.reasoning_effort),
created_at: metadata
.and_then(|metadata| metadata.created_at)
.unwrap_or_else(Utc::now),
updated_at: metadata
.and_then(|metadata| metadata.updated_at)
.unwrap_or_else(Utc::now),
archived_at: None,
cwd: PathBuf::new(),
cli_version: "test".to_string(),
source: created.source.clone(),
thread_source: created.thread_source,
agent_nickname: None,
agent_role: None,
agent_path: None,
git_info: None,
approval_mode: AskForApproval::Never,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
token_usage: None,
first_user_message: None,
cwd: metadata
.and_then(|metadata| metadata.cwd.clone())
.unwrap_or_default(),
cli_version: metadata
.and_then(|metadata| metadata.cli_version.clone())
.unwrap_or_else(|| "test".to_string()),
source: metadata
.and_then(|metadata| metadata.source.clone())
.unwrap_or_else(|| created.source.clone()),
thread_source: metadata
.and_then(|metadata| metadata.thread_source)
.unwrap_or(created.thread_source),
agent_nickname: metadata.and_then(|metadata| metadata.agent_nickname.clone().flatten()),
agent_role: metadata.and_then(|metadata| metadata.agent_role.clone().flatten()),
agent_path: metadata.and_then(|metadata| metadata.agent_path.clone().flatten()),
git_info: metadata.and_then(git_info_from_patch),
approval_mode: metadata
.and_then(|metadata| metadata.approval_mode)
.unwrap_or(AskForApproval::Never),
sandbox_policy: metadata
.and_then(|metadata| metadata.sandbox_policy.clone())
.unwrap_or_else(SandboxPolicy::new_read_only_policy),
token_usage: metadata.and_then(|metadata| metadata.token_usage.clone()),
first_user_message: metadata.and_then(|metadata| metadata.first_user_message.clone()),
history,
})
}
/// Builds a protocol `GitInfo` from the git portion of a metadata patch.
///
/// Returns `None` when the patch has no git section, or when none of the sha,
/// branch, and origin URL resolve to a value after flattening.
fn git_info_from_patch(patch: &ThreadMetadataPatch) -> Option<codex_protocol::protocol::GitInfo> {
    let patch_git = patch.git_info.as_ref()?;
    let commit = patch_git.sha.clone().flatten();
    let branch = patch_git.branch.clone().flatten();
    let repository_url = patch_git.origin_url.clone().flatten();

    // Only produce a value when at least one field is actually set.
    let any_present = commit.is_some() || branch.is_some() || repository_url.is_some();
    any_present.then(|| codex_protocol::protocol::GitInfo {
        commit_hash: commit.as_deref().map(codex_git_utils::GitSha::new),
        branch,
        repository_url,
    })
}
+2 -1
View File
@@ -9,6 +9,7 @@ mod in_memory;
mod live_thread;
mod local;
mod store;
mod thread_metadata_sync;
mod types;
pub use error::ThreadStoreError;
@@ -22,6 +23,7 @@ pub use local::LocalThreadStoreConfig;
pub use store::ThreadStore;
pub use types::AppendThreadItemsParams;
pub use types::ArchiveThreadParams;
pub use types::ClearableField;
pub use types::CreateThreadParams;
pub use types::GitInfoPatch;
pub use types::ItemPage;
@@ -29,7 +31,6 @@ pub use types::ListItemsParams;
pub use types::ListThreadsParams;
pub use types::ListTurnsParams;
pub use types::LoadThreadHistoryParams;
pub use types::OptionalStringPatch;
pub use types::ReadThreadByRolloutPathParams;
pub use types::ReadThreadParams;
pub use types::ResumeThreadParams;
+109 -5
View File
@@ -4,6 +4,9 @@ use std::sync::Arc;
use codex_protocol::ThreadId;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::ThreadMemoryMode;
use codex_rollout::EventPersistenceMode;
use codex_rollout::persisted_rollout_items;
use tokio::sync::Mutex;
use tracing::warn;
use crate::AppendThreadItemsParams;
@@ -14,10 +17,12 @@ use crate::ReadThreadParams;
use crate::ResumeThreadParams;
use crate::StoredThread;
use crate::StoredThreadHistory;
use crate::ThreadEventPersistenceMode;
use crate::ThreadMetadataPatch;
use crate::ThreadStore;
use crate::ThreadStoreResult;
use crate::UpdateThreadMetadataParams;
use crate::thread_metadata_sync::ThreadMetadataSync;
/// Handle for an active thread's persistence lifecycle.
///
@@ -28,6 +33,8 @@ use crate::UpdateThreadMetadataParams;
pub struct LiveThread {
/// Id of the thread; sent with every store call made through this handle.
thread_id: ThreadId,
/// Store that receives history appends, metadata updates, and lifecycle calls.
thread_store: Arc<dyn ThreadStore>,
/// Persistence policy passed to `persisted_rollout_items` to select which items get appended.
event_persistence_mode: EventPersistenceMode,
/// Per-thread metadata sync state: observes appended items and holds the pending metadata patch.
metadata_sync: Arc<Mutex<ThreadMetadataSync>>,
}
/// Owns a live thread while session initialization is still fallible.
@@ -85,43 +92,96 @@ impl LiveThread {
params: CreateThreadParams,
) -> ThreadStoreResult<Self> {
let thread_id = params.thread_id;
let event_persistence_mode = event_persistence_mode(params.event_persistence_mode);
let metadata_sync = ThreadMetadataSync::for_create(&params).await;
thread_store.create_thread(params).await?;
Ok(Self {
thread_id,
thread_store,
event_persistence_mode,
metadata_sync: Arc::new(Mutex::new(metadata_sync)),
})
}
pub async fn resume(
thread_store: Arc<dyn ThreadStore>,
params: ResumeThreadParams,
mut params: ResumeThreadParams,
) -> ThreadStoreResult<Self> {
let thread_id = params.thread_id;
thread_store.resume_thread(params).await?;
let event_persistence_mode = event_persistence_mode(params.event_persistence_mode);
let should_load_history = params.history.is_none();
let include_archived = params.include_archived;
thread_store.resume_thread(params.clone()).await?;
if should_load_history {
match thread_store
.load_history(LoadThreadHistoryParams {
thread_id,
include_archived,
})
.await
{
Ok(history) => params.history = Some(history.items),
Err(err) => {
let _ = thread_store.discard_thread(thread_id).await;
return Err(err);
}
}
}
let metadata_sync = ThreadMetadataSync::for_resume(&params);
Ok(Self {
thread_id,
thread_store,
event_persistence_mode,
metadata_sync: Arc::new(Mutex::new(metadata_sync)),
})
}
pub async fn append_items(&self, items: &[RolloutItem]) -> ThreadStoreResult<()> {
let canonical_items = persisted_rollout_items(items, self.event_persistence_mode);
if canonical_items.is_empty() {
return Ok(());
}
self.thread_store
.append_items(AppendThreadItemsParams {
thread_id: self.thread_id,
items: items.to_vec(),
items: canonical_items.clone(),
})
.await?;
let update = self
.metadata_sync
.lock()
.await
.observe_appended_items(canonical_items.as_slice());
if let Some(update) = update {
self.thread_store
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id: self.thread_id,
patch: update.patch.clone(),
include_archived: true,
})
.await?;
self.metadata_sync
.lock()
.await
.mark_pending_update_applied(&update);
}
Ok(())
}
pub async fn persist(&self) -> ThreadStoreResult<()> {
self.thread_store.persist_thread(self.thread_id).await
self.thread_store.persist_thread(self.thread_id).await?;
self.flush_pending_metadata_update().await
}
pub async fn flush(&self) -> ThreadStoreResult<()> {
self.thread_store.flush_thread(self.thread_id).await
self.thread_store.flush_thread(self.thread_id).await?;
self.flush_pending_metadata_update_for_existing_history()
.await
}
/// Flushes the pending metadata update for existing history, then asks the store to shut
/// the thread down. If the flush returns an error, that error is returned and the store's
/// `shutdown_thread` is not called.
pub async fn shutdown(&self) -> ThreadStoreResult<()> {
self.flush_pending_metadata_update_for_existing_history()
.await?;
self.thread_store.shutdown_thread(self.thread_id).await
}
@@ -160,6 +220,7 @@ impl LiveThread {
mode: ThreadMemoryMode,
include_archived: bool,
) -> ThreadStoreResult<()> {
self.flush_pending_metadata_update().await?;
self.thread_store
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id: self.thread_id,
@@ -178,6 +239,7 @@ impl LiveThread {
patch: ThreadMetadataPatch,
include_archived: bool,
) -> ThreadStoreResult<StoredThread> {
self.flush_pending_metadata_update().await?;
self.thread_store
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id: self.thread_id,
@@ -203,4 +265,46 @@ impl LiveThread {
.await
.map(Some)
}
/// Takes the pending metadata patch (if any) from the sync helper and sends it to the store.
async fn flush_pending_metadata_update(&self) -> ThreadStoreResult<()> {
// The lock guard is a temporary that is released at the end of this statement, so the
// mutex is free when `apply_pending_metadata_update` locks it again after the store call.
let update = self.metadata_sync.lock().await.take_pending_update();
self.apply_pending_metadata_update(update).await
}
/// Like `flush_pending_metadata_update`, but obtains the patch through
/// `take_pending_update_for_existing_history`.
// NOTE(review): which pending updates that method returns is decided inside
// `ThreadMetadataSync` and is not visible here; confirm the exact condition there.
async fn flush_pending_metadata_update_for_existing_history(&self) -> ThreadStoreResult<()> {
// The guard is dropped at the end of the `let` statement, before the update is applied.
let update = self
.metadata_sync
.lock()
.await
.take_pending_update_for_existing_history();
self.apply_pending_metadata_update(update).await
}
/// Sends a pending metadata patch to the store and, once the store accepts it,
/// tells the sync helper that the patch has been applied.
///
/// `None` is a no-op. A store error is returned as-is and the patch is not
/// marked as applied.
async fn apply_pending_metadata_update(
    &self,
    update: Option<crate::thread_metadata_sync::PendingThreadMetadataPatch>,
) -> ThreadStoreResult<()> {
    if let Some(pending) = update {
        let params = UpdateThreadMetadataParams {
            thread_id: self.thread_id,
            patch: pending.patch.clone(),
            include_archived: true,
        };
        self.thread_store.update_thread_metadata(params).await?;

        // Lock only after the store call has finished, and only for the bookkeeping step.
        self.metadata_sync
            .lock()
            .await
            .mark_pending_update_applied(&pending);
    }
    Ok(())
}
}
fn event_persistence_mode(mode: ThreadEventPersistenceMode) -> EventPersistenceMode {
match mode {
ThreadEventPersistenceMode::Limited => EventPersistenceMode::Limited,
ThreadEventPersistenceMode::Extended => EventPersistenceMode::Extended,
}
}
@@ -1,10 +1,8 @@
use super::LocalThreadStore;
use crate::CreateThreadParams;
use crate::ThreadEventPersistenceMode;
use crate::ThreadStoreError;
use crate::ThreadStoreResult;
use codex_protocol::protocol::ThreadMemoryMode;
use codex_rollout::EventPersistenceMode;
use codex_rollout::RolloutConfig;
use codex_rollout::RolloutRecorder;
use codex_rollout::RolloutRecorderParams;
@@ -27,7 +25,6 @@ pub(super) async fn create_thread(
model_provider_id: params.metadata.model_provider.clone(),
generate_memories: matches!(params.metadata.memory_mode, ThreadMemoryMode::Enabled),
};
let state_db_ctx = store.state_db().await;
let recorder = RolloutRecorder::new(
&config,
RolloutRecorderParams::new(
@@ -37,10 +34,7 @@ pub(super) async fn create_thread(
params.thread_source,
params.base_instructions,
params.dynamic_tools,
event_persistence_mode(params.event_persistence_mode),
),
state_db_ctx,
/*state_builder*/ None,
)
.await
.map_err(|err| ThreadStoreError::Internal {
@@ -49,10 +43,3 @@ pub(super) async fn create_thread(
Ok(recorder)
}
/// Maps each `ThreadEventPersistenceMode` variant to the `EventPersistenceMode`
/// variant of the same name.
pub(super) fn event_persistence_mode(mode: ThreadEventPersistenceMode) -> EventPersistenceMode {
    type Requested = ThreadEventPersistenceMode;
    type Recorder = EventPersistenceMode;

    match mode {
        Requested::Extended => Recorder::Extended,
        Requested::Limited => Recorder::Limited,
    }
}
+66 -33
View File
@@ -5,7 +5,7 @@ use codex_protocol::protocol::ThreadMemoryMode;
use codex_rollout::RolloutConfig;
use codex_rollout::RolloutRecorder;
use codex_rollout::RolloutRecorderParams;
use codex_rollout::builder_from_items;
use tracing::warn;
use super::LocalThreadStore;
use super::create_thread;
@@ -31,8 +31,8 @@ pub(super) async fn resume_thread(
params: ResumeThreadParams,
) -> ThreadStoreResult<()> {
store.ensure_live_recorder_absent(params.thread_id).await?;
let (rollout_path, history) = match (params.rollout_path, params.history) {
(Some(rollout_path), history) => (rollout_path, history),
let rollout_path = match (params.rollout_path, params.history) {
(Some(rollout_path), _history) => rollout_path,
(None, history) => {
let thread = super::read_thread::read_thread(
store,
@@ -43,20 +43,14 @@ pub(super) async fn resume_thread(
},
)
.await?;
let rollout_path = thread
thread
.rollout_path
.ok_or_else(|| ThreadStoreError::Internal {
message: format!("thread {} does not have a rollout path", params.thread_id),
})?;
(
rollout_path,
history.or_else(|| thread.history.map(|history| history.items)),
)
})?
}
};
let state_builder = history
.as_deref()
.and_then(|items| builder_from_items(items, rollout_path.as_path()));
let cwd = params
.metadata
.cwd
@@ -71,20 +65,11 @@ pub(super) async fn resume_thread(
model_provider_id: params.metadata.model_provider.clone(),
generate_memories: matches!(params.metadata.memory_mode, ThreadMemoryMode::Enabled),
};
let state_db_ctx = store.state_db().await;
let recorder = RolloutRecorder::new(
&config,
RolloutRecorderParams::resume(
rollout_path,
create_thread::event_persistence_mode(params.event_persistence_mode),
),
state_db_ctx,
state_builder,
)
.await
.map_err(|err| ThreadStoreError::Internal {
message: format!("failed to resume local thread recorder: {err}"),
})?;
let recorder = RolloutRecorder::new(&config, RolloutRecorderParams::resume(rollout_path))
.await
.map_err(|err| ThreadStoreError::Internal {
message: format!("failed to resume local thread recorder: {err}"),
})?;
store.insert_live_recorder(params.thread_id, recorder).await
}
@@ -92,12 +77,14 @@ pub(super) async fn append_items(
store: &LocalThreadStore,
params: AppendThreadItemsParams,
) -> ThreadStoreResult<()> {
store
.live_recorder(params.thread_id)
.await?
.record_items(params.items.as_slice())
let recorder = store.live_recorder(params.thread_id).await?;
recorder
.record_canonical_items(params.items.as_slice())
.await
.map_err(thread_store_io_error)
.map_err(thread_store_io_error)?;
// LiveThread applies metadata immediately after append_items returns. Wait for the local
// writer so SQLite never gets ahead of JSONL for accepted live appends.
recorder.flush().await.map_err(thread_store_io_error)
}
pub(super) async fn persist_thread(
@@ -109,7 +96,8 @@ pub(super) async fn persist_thread(
.await?
.persist()
.await
.map_err(thread_store_io_error)
.map_err(thread_store_io_error)?;
sync_materialized_rollout_path(store, thread_id).await
}
pub(super) async fn flush_thread(
@@ -121,7 +109,8 @@ pub(super) async fn flush_thread(
.await?
.flush()
.await
.map_err(thread_store_io_error)
.map_err(thread_store_io_error)?;
sync_materialized_rollout_path(store, thread_id).await
}
pub(super) async fn shutdown_thread(
@@ -130,6 +119,7 @@ pub(super) async fn shutdown_thread(
) -> ThreadStoreResult<()> {
let recorder = store.live_recorder(thread_id).await?;
recorder.shutdown().await.map_err(thread_store_io_error)?;
sync_materialized_rollout_path(store, thread_id).await?;
store.live_recorders.lock().await.remove(&thread_id);
Ok(())
}
@@ -161,6 +151,49 @@ pub(super) async fn rollout_path(
.to_path_buf())
}
/// Makes the SQLite thread row point at the live recorder's rollout file once
/// that file exists on disk.
///
/// Nothing happens when the file is absent (or its existence cannot be
/// checked), when no state DB is available, or when the thread has no row.
/// Failures while reading or updating the row are logged with `warn!` and not
/// returned; only a failure to resolve the live rollout path is propagated.
async fn sync_materialized_rollout_path(
    store: &LocalThreadStore,
    thread_id: ThreadId,
) -> ThreadStoreResult<()> {
    let live_path = rollout_path(store, thread_id).await?;

    // An error from the existence check is treated the same as "not there yet".
    let materialized = tokio::fs::try_exists(live_path.as_path())
        .await
        .unwrap_or(false);
    if !materialized {
        return Ok(());
    }

    let state_db = match store.state_db().await {
        Some(state_db) => state_db,
        None => return Ok(()),
    };

    let outcome: ThreadStoreResult<()> = async {
        let stored = state_db.get_thread(thread_id).await.map_err(|err| {
            ThreadStoreError::Internal {
                message: format!("failed to read thread metadata for {thread_id}: {err}"),
            }
        })?;
        let Some(mut metadata) = stored else {
            return Ok(());
        };
        // Skip the write when the row already points at the live path.
        if metadata.rollout_path == live_path {
            return Ok(());
        }
        metadata.rollout_path = live_path;
        state_db.upsert_thread(&metadata).await.map_err(|err| {
            ThreadStoreError::Internal {
                message: format!("failed to update thread metadata for {thread_id}: {err}"),
            }
        })?;
        Ok(())
    }
    .await;

    // Best effort: report the problem but do not fail the caller.
    if let Err(err) = outcome {
        warn!("failed to sync materialized rollout path for thread {thread_id}: {err}");
    }
    Ok(())
}
fn thread_store_io_error(err: std::io::Error) -> ThreadStoreError {
ThreadStoreError::Internal {
message: err.to_string(),
+276
View File
@@ -37,6 +37,18 @@ use crate::ThreadStoreResult;
use crate::UpdateThreadMetadataParams;
/// Local filesystem/SQLite-backed implementation of [`ThreadStore`].
///
/// Local storage has two compatibility surfaces. Rollout JSONL files are the
/// durable replay format and remain readable without SQLite, including older
/// files that encode metadata in `SessionMeta` items and name-index entries.
/// The SQLite state DB, when available, is the queryable metadata index used by
/// list/read paths for fast lookup.
///
/// Live appends still write canonical JSONL history, but append-derived
/// metadata is observed above the store and applied through
/// [`ThreadStore::update_thread_metadata`]. This implementation applies that
/// patch literally to SQLite while keeping the JSONL/name-index compatibility
/// behavior needed for SQLite-less reads, repair, and old local rollout files.
#[derive(Clone)]
pub struct LocalThreadStore {
pub(super) config: LocalThreadStoreConfig,
@@ -270,6 +282,8 @@ impl ThreadStore for LocalThreadStore {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use codex_protocol::ThreadId;
use codex_protocol::models::BaseInstructions;
use codex_protocol::protocol::EventMsg;
@@ -280,6 +294,7 @@ mod tests {
use tempfile::TempDir;
use super::*;
use crate::LiveThread;
use crate::ThreadEventPersistenceMode;
use crate::ThreadPersistenceMetadata;
use crate::local::test_support::test_config;
@@ -335,6 +350,267 @@ mod tests {
);
}
#[tokio::test]
async fn raw_append_items_does_not_update_sqlite_metadata() {
    // Contract check for ThreadStore: a raw append only writes history. Metadata changes
    // require LiveThread or an explicit update_thread_metadata call.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = config.sqlite_home.clone();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = LocalThreadStore::new(config, Some(state_runtime.clone()));
    let thread_id = ThreadId::default();

    let create_params = create_thread_params(thread_id);
    store
        .create_thread(create_params)
        .await
        .expect("create live thread");

    let raw_append = AppendThreadItemsParams {
        thread_id,
        items: vec![user_message_item("raw append")],
    };
    store
        .append_items(raw_append)
        .await
        .expect("append raw item");
    store.flush_thread(thread_id).await.expect("flush thread");

    // No SQLite row may exist after a history-only append.
    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    assert_eq!(sqlite_row, None);
}
#[tokio::test]
async fn live_thread_observes_appended_items_into_sqlite_metadata() {
    // Appending through LiveThread (rather than the raw store) must surface the first user
    // message as preview, title, and first_user_message in SQLite.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = config.sqlite_home.clone();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = Arc::new(LocalThreadStore::new(config, Some(state_runtime.clone())));
    let thread_id = ThreadId::default();
    let live_thread = LiveThread::create(store.clone(), create_thread_params(thread_id))
        .await
        .expect("create live thread");

    let observed_items = [user_message_item("observed append")];
    live_thread
        .append_items(&observed_items)
        .await
        .expect("append observed item");
    live_thread.flush().await.expect("flush thread");

    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    let metadata = sqlite_row.expect("sqlite metadata");
    let first_user_message = metadata.first_user_message.as_deref();
    assert_eq!(first_user_message, Some("observed append"));
    let preview = metadata.preview.as_deref();
    assert_eq!(preview, Some("observed append"));
    assert_eq!(metadata.title, "observed append");
}
#[tokio::test]
async fn live_thread_shutdown_does_not_materialize_empty_thread_metadata() {
    // Shutting down a thread that never received items must leave neither a rollout file
    // nor a SQLite row behind.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = config.sqlite_home.clone();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = Arc::new(LocalThreadStore::new(config, Some(state_runtime.clone())));
    let thread_id = ThreadId::default();
    let live_thread = LiveThread::create(store.clone(), create_thread_params(thread_id))
        .await
        .expect("create live thread");
    let expected_rollout_path = store
        .live_rollout_path(thread_id)
        .await
        .expect("live rollout path");

    live_thread.shutdown().await.expect("shutdown thread");

    let rollout_exists = tokio::fs::try_exists(expected_rollout_path.as_path())
        .await
        .expect("rollout path should be checkable");
    assert!(!rollout_exists);
    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    assert_eq!(sqlite_row, None);
}
#[tokio::test]
async fn live_thread_shutdown_with_buffered_items_materializes_before_metadata_read() {
    // After one appended item, shutdown must leave the rollout file on disk and a SQLite
    // row whose rollout_path points at it.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = config.sqlite_home.clone();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = Arc::new(LocalThreadStore::new(config, Some(state_runtime.clone())));
    let thread_id = ThreadId::default();
    let live_thread = LiveThread::create(store.clone(), create_thread_params(thread_id))
        .await
        .expect("create live thread");
    let expected_rollout_path = store
        .live_rollout_path(thread_id)
        .await
        .expect("live rollout path");

    let token_count = codex_protocol::protocol::TokenCountEvent {
        info: None,
        rate_limits: None,
    };
    let buffered_items = [RolloutItem::EventMsg(EventMsg::TokenCount(token_count))];
    live_thread
        .append_items(&buffered_items)
        .await
        .expect("append metadata-only item");
    live_thread.shutdown().await.expect("shutdown thread");

    let rollout_exists = tokio::fs::try_exists(expected_rollout_path.as_path())
        .await
        .expect("rollout path should be checkable");
    assert!(rollout_exists);
    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    let metadata = sqlite_row.expect("sqlite metadata");
    assert_eq!(metadata.rollout_path, expected_rollout_path);
}
#[tokio::test]
async fn live_thread_resume_loads_history_before_observing_metadata() {
    // Resuming from an existing rollout and then appending must keep the metadata that the
    // rollout already established (creation time, provider, first user message) rather
    // than the values supplied at resume time or by the new append.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = config.sqlite_home.clone();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = Arc::new(LocalThreadStore::new(config, Some(state_runtime.clone())));
    let session_uuid = uuid::Uuid::from_u128(401);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");
    let rollout_path = write_session_file(codex_home.path(), "2025-01-03T17-00-00", session_uuid)
        .expect("session file");

    let resume_metadata = ThreadPersistenceMetadata {
        cwd: Some(codex_home.path().to_path_buf()),
        model_provider: "different-provider".to_string(),
        memory_mode: ThreadMemoryMode::Enabled,
    };
    let resume_params = ResumeThreadParams {
        thread_id,
        rollout_path: Some(rollout_path),
        history: None,
        include_archived: false,
        metadata: resume_metadata,
        event_persistence_mode: ThreadEventPersistenceMode::Limited,
    };
    let live_thread = LiveThread::resume(store, resume_params)
        .await
        .expect("resume live thread");

    let appended_items = [user_message_item("new live append")];
    live_thread
        .append_items(&appended_items)
        .await
        .expect("append after resume");

    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    let metadata = sqlite_row.expect("sqlite metadata");
    let created_at = metadata.created_at.to_rfc3339();
    assert_eq!(created_at, "2025-01-03T17:00:00+00:00");
    assert_eq!(metadata.model_provider, "test-provider");
    let first_user_message = metadata.first_user_message.as_deref();
    assert_eq!(first_user_message, Some("Hello from user"));
}
#[tokio::test]
async fn live_thread_resume_loads_history_from_explicit_external_rollout_path() {
    // Same expectation as resuming a local rollout, but the rollout file lives in a
    // separate temp directory and is passed explicitly via `rollout_path`.
    let codex_home = TempDir::new().expect("temp dir");
    let external_home = TempDir::new().expect("external temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = config.sqlite_home.clone();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = Arc::new(LocalThreadStore::new(config, Some(state_runtime.clone())));
    let session_uuid = uuid::Uuid::from_u128(402);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");
    let external_rollout_path =
        write_session_file(external_home.path(), "2025-01-03T17-30-00", session_uuid)
            .expect("external session file");

    let resume_metadata = ThreadPersistenceMetadata {
        cwd: Some(codex_home.path().to_path_buf()),
        model_provider: "different-provider".to_string(),
        memory_mode: ThreadMemoryMode::Enabled,
    };
    let resume_params = ResumeThreadParams {
        thread_id,
        rollout_path: Some(external_rollout_path),
        history: None,
        include_archived: false,
        metadata: resume_metadata,
        event_persistence_mode: ThreadEventPersistenceMode::Limited,
    };
    let live_thread = LiveThread::resume(store, resume_params)
        .await
        .expect("resume external live thread");

    let appended_items = [user_message_item("new external append")];
    live_thread
        .append_items(&appended_items)
        .await
        .expect("append after external resume");

    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    let metadata = sqlite_row.expect("sqlite metadata");
    let created_at = metadata.created_at.to_rfc3339();
    assert_eq!(created_at, "2025-01-03T17:30:00+00:00");
    assert_eq!(metadata.model_provider, "test-provider");
    let first_user_message = metadata.first_user_message.as_deref();
    assert_eq!(first_user_message, Some("Hello from user"));
}
#[tokio::test]
async fn create_thread_rejects_missing_cwd() {
let home = TempDir::new().expect("temp dir");
@@ -273,7 +273,8 @@ async fn stored_thread_from_sqlite_metadata(
None => find_thread_name_by_id(store.config.codex_home.as_path(), &metadata.id)
.await
.ok()
.flatten(),
.flatten()
.filter(|title| !title.trim().is_empty()),
};
let session_meta = read_session_meta_line(metadata.rollout_path.as_path())
.await
@@ -1,9 +1,11 @@
use std::path::Path;
use std::path::PathBuf;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::protocol::GitInfo;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::ThreadMemoryMode;
use codex_rollout::ARCHIVED_SESSIONS_SUBDIR;
use codex_rollout::append_rollout_item_to_path;
@@ -11,6 +13,7 @@ use codex_rollout::append_thread_name;
use codex_rollout::find_archived_thread_path_by_id_str;
use codex_rollout::find_thread_path_by_id_str;
use codex_rollout::read_session_meta_line;
use codex_state::ThreadMetadataBuilder;
use super::LocalThreadStore;
use super::helpers::git_info_from_parts;
@@ -18,6 +21,7 @@ use super::live_writer;
use crate::GitInfoPatch;
use crate::ReadThreadParams;
use crate::StoredThread;
use crate::ThreadMetadataPatch;
use crate::ThreadStoreError;
use crate::ThreadStoreResult;
use crate::UpdateThreadMetadataParams;
@@ -32,25 +36,35 @@ pub(super) async fn update_thread_metadata(
store: &LocalThreadStore,
params: UpdateThreadMetadataParams,
) -> ThreadStoreResult<StoredThread> {
let field_count = usize::from(params.patch.name.is_some())
+ usize::from(params.patch.memory_mode.is_some())
+ usize::from(params.patch.git_info.is_some());
if field_count > 1 {
return Err(ThreadStoreError::InvalidRequest {
message: "local thread store applies one metadata field per patch in this slice"
.to_string(),
});
let thread_id = params.thread_id;
let patch = params.patch;
if patch.is_empty() {
return read_thread::read_thread(
store,
ReadThreadParams {
thread_id,
include_archived: params.include_archived,
include_history: false,
},
)
.await;
}
let needs_rollout_compat = needs_rollout_compatibility_update(&patch);
let updated =
apply_metadata_update(store, thread_id, patch.clone(), params.include_archived).await?;
if !needs_rollout_compat {
return Ok(updated);
}
let thread_id = params.thread_id;
if live_writer::rollout_path(store, thread_id).await.is_ok() {
live_writer::persist_thread(store, thread_id).await?;
}
let resolved_rollout_path =
resolve_rollout_path(store, thread_id, params.include_archived).await?;
let name = params.patch.name;
let git_info = params.patch.git_info;
if let Some(memory_mode) = params.patch.memory_mode {
let name = patch.name;
let git_info = patch.git_info;
if let Some(memory_mode) = patch.memory_mode {
apply_thread_memory_mode(resolved_rollout_path.path.as_path(), thread_id, memory_mode)
.await?;
}
@@ -68,7 +82,7 @@ pub(super) async fn update_thread_metadata(
.await;
if let Some(name) = name {
apply_thread_name(store, thread_id, name).await?;
apply_thread_name(store, thread_id, name.unwrap_or_default()).await?;
}
let resolved_git_info = match git_info {
@@ -150,6 +164,218 @@ pub(super) async fn update_thread_metadata(
Ok(thread)
}
/// Applies `patch` literally to the thread's SQLite metadata row, then re-reads the thread.
///
/// When no state DB is available the SQLite step is skipped entirely and the function only
/// performs the final read. When a state DB is available:
///
/// * The rollout path is taken from the patch, else from the live writer. If SQLite has no
///   row yet and neither source yields a path, the path is resolved via
///   `resolve_rollout_path`; failure to resolve is returned to the caller, so a
///   metadata-only patch cannot create a row for an unknown thread.
/// * A missing row is seeded from the patch (falling back to `Utc::now()` and
///   `SessionSource::Unknown`) and is marked archived when its rollout path is archived.
/// * Every field present in the patch overwrites the row. `title` is applied after `name`,
///   so `title` wins when both are present.
/// * `memory_mode` and `dynamic_tools` are written through their dedicated state-DB calls
///   after the row upsert.
///
/// # Errors
///
/// Returns `ThreadStoreError::Internal` for state-DB read/write failures, and whatever
/// `resolve_rollout_path` or `read_thread::read_thread` return.
async fn apply_metadata_update(
    store: &LocalThreadStore,
    thread_id: ThreadId,
    patch: ThreadMetadataPatch,
    include_archived: bool,
) -> ThreadStoreResult<StoredThread> {
    let live_rollout_path = live_writer::rollout_path(store, thread_id).await.ok();
    let mut rollout_path = patch.rollout_path.clone().or(live_rollout_path);
    let mut rollout_path_archived = rollout_path
        .as_deref()
        .is_some_and(|path| rollout_path_is_archived(store, path));

    if let Some(state_db) = store.state_db().await {
        let existing =
            state_db
                .get_thread(thread_id)
                .await
                .map_err(|err| ThreadStoreError::Internal {
                    message: format!("failed to read thread metadata for {thread_id}: {err}"),
                })?;
        // Without a row or any known path, the thread must be discoverable on disk.
        if existing.is_none() && rollout_path.is_none() {
            let resolved = resolve_rollout_path(store, thread_id, include_archived).await?;
            rollout_path_archived = resolved.archived;
            rollout_path = Some(resolved.path);
        }
        let mut metadata = existing.unwrap_or_else(|| {
            let created_at = patch
                .created_at
                .or(patch.updated_at)
                .unwrap_or_else(Utc::now);
            let mut builder = ThreadMetadataBuilder::new(
                thread_id,
                rollout_path.clone().unwrap_or_default(),
                created_at,
                patch.source.clone().unwrap_or(SessionSource::Unknown),
            );
            builder.model_provider = patch.model_provider.clone();
            builder.thread_source = patch.thread_source.flatten();
            builder.agent_nickname = patch.agent_nickname.clone().flatten();
            builder.agent_role = patch.agent_role.clone().flatten();
            builder.agent_path = patch.agent_path.clone().flatten();
            builder.cwd = patch.cwd.clone().map(normalize_cwd).unwrap_or_default();
            builder.cli_version = patch.cli_version.clone();
            let mut seeded = builder.build(store.config.default_model_provider_id.as_str());
            if rollout_path_archived {
                seeded.archived_at = Some(seeded.updated_at);
            }
            seeded
        });

        if let Some(rollout_path) = rollout_path {
            metadata.rollout_path = rollout_path;
        }
        if let Some(preview) = patch.preview {
            metadata.preview = Some(preview);
        }
        if let Some(name) = patch.name {
            metadata.title = name.unwrap_or_default();
        }
        // Applied after `name` on purpose: an explicit title overrides it.
        if let Some(title) = patch.title {
            metadata.title = title;
        }
        if let Some(model_provider) = patch.model_provider {
            metadata.model_provider = model_provider;
        }
        if let Some(model) = patch.model {
            metadata.model = Some(model);
        }
        if let Some(reasoning_effort) = patch.reasoning_effort {
            metadata.reasoning_effort = Some(reasoning_effort);
        }
        if let Some(created_at) = patch.created_at {
            metadata.created_at = created_at;
        }
        if let Some(updated_at) = patch.updated_at {
            metadata.updated_at = updated_at;
        }
        if let Some(source) = patch.source {
            metadata.source = enum_to_string(&source);
        }
        if let Some(thread_source) = patch.thread_source {
            metadata.thread_source = thread_source;
        }
        if let Some(agent_nickname) = patch.agent_nickname {
            metadata.agent_nickname = agent_nickname;
        }
        if let Some(agent_role) = patch.agent_role {
            metadata.agent_role = agent_role;
        }
        if let Some(agent_path) = patch.agent_path {
            metadata.agent_path = agent_path;
        }
        if let Some(cwd) = patch.cwd {
            metadata.cwd = normalize_cwd(cwd);
        }
        if let Some(cli_version) = patch.cli_version {
            metadata.cli_version = cli_version;
        }
        if let Some(approval_mode) = patch.approval_mode {
            metadata.approval_mode = enum_to_string(&approval_mode);
        }
        if let Some(sandbox_policy) = patch.sandbox_policy {
            metadata.sandbox_policy = enum_to_string(&sandbox_policy);
        }
        if let Some(token_usage) = patch.token_usage {
            // Negative totals are clamped to zero before being stored.
            metadata.tokens_used = token_usage.total_tokens.max(0);
        }
        if let Some(first_user_message) = patch.first_user_message {
            metadata.first_user_message = Some(first_user_message);
        }
        if let Some(git_info) = patch.git_info {
            let existing_git_info = git_info_from_parts(
                metadata.git_sha.clone(),
                metadata.git_branch.clone(),
                metadata.git_origin_url.clone(),
            );
            let (sha, branch, origin_url) = resolve_git_info_patch(existing_git_info, git_info);
            metadata.git_sha = sha;
            metadata.git_branch = branch;
            metadata.git_origin_url = origin_url;
        }

        state_db
            .upsert_thread(&metadata)
            .await
            .map_err(|err| ThreadStoreError::Internal {
                message: format!("failed to update thread metadata for {thread_id}: {err}"),
            })?;
        if let Some(memory_mode) = patch.memory_mode {
            state_db
                .set_thread_memory_mode(thread_id, memory_mode_as_str(memory_mode))
                .await
                .map_err(|err| ThreadStoreError::Internal {
                    message: format!("failed to update memory mode for {thread_id}: {err}"),
                })?;
        }
        if let Some(dynamic_tools) = patch.dynamic_tools {
            state_db
                .persist_dynamic_tools(thread_id, Some(dynamic_tools.as_slice()))
                .await
                .map_err(|err| ThreadStoreError::Internal {
                    message: format!("failed to update dynamic tools for {thread_id}: {err}"),
                })?;
        }
    }

    read_thread::read_thread(
        store,
        ReadThreadParams {
            thread_id,
            include_archived,
            include_history: false,
        },
    )
    .await
}
/// Decides whether a patch must also go through the rollout/name-index compatibility path.
///
/// A patch that sets `name` always does. A patch that sets `memory_mode` or `git_info`
/// does only when it carries none of the fields checked by `has_observed_metadata_facts`.
/// Every other patch does not.
fn needs_rollout_compatibility_update(patch: &ThreadMetadataPatch) -> bool {
    let touches_session_meta = patch.memory_mode.is_some() || patch.git_info.is_some();
    patch.name.is_some() || (touches_session_meta && !has_observed_metadata_facts(patch))
}
/// Reports whether the patch sets any of the fields listed below.
///
/// `name`, `updated_at`, `memory_mode`, and `git_info` are not in the list, so a patch that
/// sets only those yields `false`.
fn has_observed_metadata_facts(patch: &ThreadMetadataPatch) -> bool {
    let observed_fields = [
        patch.rollout_path.is_some(),
        patch.preview.is_some(),
        patch.title.is_some(),
        patch.model_provider.is_some(),
        patch.model.is_some(),
        patch.reasoning_effort.is_some(),
        patch.created_at.is_some(),
        patch.source.is_some(),
        patch.thread_source.is_some(),
        patch.agent_nickname.is_some(),
        patch.agent_role.is_some(),
        patch.agent_path.is_some(),
        patch.cwd.is_some(),
        patch.cli_version.is_some(),
        patch.approval_mode.is_some(),
        patch.sandbox_policy.is_some(),
        patch.token_usage.is_some(),
        patch.first_user_message.is_some(),
        patch.dynamic_tools.is_some(),
    ];
    observed_fields.contains(&true)
}
/// Converts a serializable value into a plain string.
///
/// The value is first serialized to a `serde_json::Value`. A JSON string is returned
/// without its surrounding quotes; any other JSON value is returned as its JSON text.
/// If serialization fails, the result is an empty string.
fn enum_to_string<T: serde::Serialize>(value: &T) -> String {
    serde_json::to_value(value)
        .map(|json| match json {
            serde_json::Value::String(text) => text,
            non_string => non_string.to_string(),
        })
        .unwrap_or_default()
}
/// Normalizes `cwd` with `codex_utils_path::normalize_for_path_comparison`, returning the
/// input unchanged when normalization does not produce a value.
fn normalize_cwd(cwd: PathBuf) -> PathBuf {
    let normalized = codex_utils_path::normalize_for_path_comparison(cwd.as_path());
    normalized.unwrap_or(cwd)
}
async fn apply_thread_git_info(
store: &LocalThreadStore,
thread_id: ThreadId,
@@ -363,10 +589,13 @@ mod tests {
use super::*;
use crate::GitInfoPatch;
use crate::ListThreadsParams;
use crate::ResumeThreadParams;
use crate::SortDirection;
use crate::ThreadEventPersistenceMode;
use crate::ThreadMetadataPatch;
use crate::ThreadPersistenceMetadata;
use crate::ThreadSortKey;
use crate::ThreadStore;
use crate::local::LocalThreadStore;
use crate::local::test_support::test_config;
@@ -385,7 +614,7 @@ mod tests {
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id,
patch: ThreadMetadataPatch {
name: Some("A sharper name".to_string()),
name: Some(Some("A sharper name".to_string())),
..Default::default()
},
include_archived: false,
@@ -855,44 +1084,306 @@ mod tests {
}
#[tokio::test]
async fn update_thread_metadata_rejects_multi_field_patch_without_partial_write() {
async fn update_thread_metadata_applies_combined_explicit_patch() {
let home = TempDir::new().expect("temp dir");
let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None);
let config = test_config(home.path());
let runtime = codex_state::StateRuntime::init(
home.path().to_path_buf(),
config.default_model_provider_id.clone(),
)
.await
.expect("state db should initialize");
let store = LocalThreadStore::new(config, Some(runtime.clone()));
let uuid = Uuid::from_u128(305);
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
let path =
write_session_file(home.path(), "2025-01-03T15-30-00", uuid).expect("session file");
let original = std::fs::read_to_string(&path).expect("read rollout");
let err = store
let thread = store
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id,
patch: ThreadMetadataPatch {
name: Some("Should not persist".to_string()),
name: Some(Some("Combined metadata".to_string())),
memory_mode: Some(ThreadMemoryMode::Disabled),
git_info: Some(GitInfoPatch {
branch: Some(Some("combined".to_string())),
..Default::default()
}),
..Default::default()
},
include_archived: false,
})
.await
.expect_err("multi-field patch should fail");
.expect("combined patch should apply");
assert!(matches!(err, ThreadStoreError::InvalidRequest { .. }));
assert_eq!(thread.name.as_deref(), Some("Combined metadata"));
assert_eq!(
std::fs::read_to_string(&path).expect("read rollout"),
original
thread.git_info.expect("git info").branch.as_deref(),
Some("combined")
);
let appended = last_rollout_item(path.as_path());
assert_eq!(appended["type"], "session_meta");
assert_eq!(appended["payload"]["memory_mode"], "disabled");
assert_eq!(appended["payload"]["git"]["branch"], "combined");
let latest_name = codex_rollout::find_thread_name_by_id(home.path(), &thread_id)
.await
.expect("find thread name");
assert_eq!(latest_name, None);
assert_eq!(latest_name.as_deref(), Some("Combined metadata"));
let memory_mode = runtime
.get_thread_memory_mode(thread_id)
.await
.expect("thread memory mode should be readable");
assert_eq!(memory_mode.as_deref(), Some("disabled"));
}
#[tokio::test]
async fn metadata_patch_applies_title_over_existing_name() {
    // A later patch carrying `title` replaces a name set by an earlier patch.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = codex_home.path().to_path_buf();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = LocalThreadStore::new(config, Some(state_runtime));
    let session_uuid = Uuid::from_u128(306);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");
    write_session_file(codex_home.path(), "2025-01-03T15-45-00", session_uuid)
        .expect("session file");

    let explicit_name_patch = ThreadMetadataPatch {
        name: Some(Some("User chosen name".to_string())),
        ..Default::default()
    };
    store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: explicit_name_patch,
            include_archived: false,
        })
        .await
        .expect("set explicit name");

    let observed_patch = ThreadMetadataPatch {
        title: Some("Derived first message".to_string()),
        preview: Some("Derived first message".to_string()),
        ..Default::default()
    };
    let updated_thread = store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: observed_patch,
            include_archived: false,
        })
        .await
        .expect("apply observed metadata");

    let final_name = updated_thread.name.as_deref();
    assert_eq!(final_name, Some("Derived first message"));
}
#[tokio::test]
async fn metadata_patch_applies_latest_preview_and_first_user_message() {
    // Two successive patches set preview/first_user_message. The SQLite row must hold the
    // later values, while the returned thread still reports "Hello from user" for both.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = codex_home.path().to_path_buf();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = LocalThreadStore::new(config, Some(state_runtime.clone()));
    let session_uuid = Uuid::from_u128(313);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");
    write_session_file(codex_home.path(), "2025-01-03T19-00-00", session_uuid)
        .expect("session file");

    let original_patch = ThreadMetadataPatch {
        preview: Some("Original preview".to_string()),
        first_user_message: Some("Original first message".to_string()),
        ..Default::default()
    };
    store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: original_patch,
            include_archived: false,
        })
        .await
        .expect("set observed metadata");

    let later_patch = ThreadMetadataPatch {
        preview: Some("Later preview".to_string()),
        first_user_message: Some("Later first message".to_string()),
        ..Default::default()
    };
    let returned_thread = store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: later_patch,
            include_archived: false,
        })
        .await
        .expect("apply later observed metadata");

    assert_eq!(returned_thread.preview, "Hello from user");
    let returned_first_message = returned_thread.first_user_message.as_deref();
    assert_eq!(returned_first_message, Some("Hello from user"));

    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    let metadata = sqlite_row.expect("sqlite metadata");
    let stored_preview = metadata.preview.as_deref();
    assert_eq!(stored_preview, Some("Later preview"));
    let stored_first_message = metadata.first_user_message.as_deref();
    assert_eq!(stored_first_message, Some("Later first message"));
}
#[tokio::test]
async fn observed_metadata_rejects_unknown_thread_without_rollout() {
    // With no rollout file and no SQLite row, a metadata patch must fail with
    // "thread not found" and must not create a row.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = codex_home.path().to_path_buf();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = LocalThreadStore::new(config, Some(state_runtime.clone()));
    let session_uuid = Uuid::from_u128(314);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");

    let phantom_patch = ThreadMetadataPatch {
        preview: Some("phantom".to_string()),
        ..Default::default()
    };
    let err = store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: phantom_patch,
            include_archived: false,
        })
        .await
        .expect_err("metadata-only update should not create a missing thread");

    let is_thread_not_found = matches!(
        err,
        ThreadStoreError::InvalidRequest { message }
            if message == format!("thread not found: {thread_id}")
    );
    assert!(is_thread_not_found);

    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("sqlite metadata read");
    assert!(sqlite_row.is_none());
}
#[tokio::test]
async fn update_thread_metadata_recreates_missing_archived_sqlite_row_as_archived() {
    // The archived rollout is written before the state DB is initialized. Patching the
    // thread must yield a thread and a SQLite row that both carry `archived_at`.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let session_uuid = Uuid::from_u128(315);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");
    write_archived_session_file(codex_home.path(), "2025-01-03T19-30-00", session_uuid)
        .expect("archived session file");
    let sqlite_home = codex_home.path().to_path_buf();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = LocalThreadStore::new(config, Some(state_runtime.clone()));

    let archived_patch = ThreadMetadataPatch {
        preview: Some("Archived missing sqlite row".to_string()),
        ..Default::default()
    };
    let updated_thread = store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: archived_patch,
            include_archived: true,
        })
        .await
        .expect("update archived thread without sqlite row");

    assert!(updated_thread.archived_at.is_some());
    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("get metadata")
        .expect("metadata");
    assert!(sqlite_row.archived_at.is_some());
}
#[tokio::test]
async fn observed_metadata_normalizes_cwd_for_list_filters() {
    // A cwd supplied as `<workspace>/child/..` must be stored in normalized form so that a
    // list query filtered by `<workspace>` finds the thread.
    let codex_home = TempDir::new().expect("temp dir");
    let config = test_config(codex_home.path());
    let sqlite_home = codex_home.path().to_path_buf();
    let provider_id = config.default_model_provider_id.clone();
    let state_runtime = codex_state::StateRuntime::init(sqlite_home, provider_id)
        .await
        .expect("state db should initialize");
    let store = LocalThreadStore::new(config, Some(state_runtime.clone()));
    let session_uuid = Uuid::from_u128(316);
    let thread_id = ThreadId::from_string(&session_uuid.to_string()).expect("valid thread id");
    write_session_file(codex_home.path(), "2025-01-03T20-00-00", session_uuid)
        .expect("session file");

    let workspace = codex_home.path().join("workspace");
    let child = workspace.join("child");
    std::fs::create_dir_all(child.as_path()).expect("create workspace");
    let unnormalized_cwd = child.join("..");
    let normalized_cwd = codex_utils_path::normalize_for_path_comparison(workspace.as_path())
        .expect("normalize cwd");

    let cwd_patch = ThreadMetadataPatch {
        cwd: Some(unnormalized_cwd),
        preview: Some("cwd preview".to_string()),
        ..Default::default()
    };
    store
        .update_thread_metadata(UpdateThreadMetadataParams {
            thread_id,
            patch: cwd_patch,
            include_archived: false,
        })
        .await
        .expect("update observed cwd");

    let sqlite_row = state_runtime
        .get_thread(thread_id)
        .await
        .expect("get metadata")
        .expect("metadata");
    assert_eq!(sqlite_row.cwd, normalized_cwd);

    let list_params = ListThreadsParams {
        page_size: 10,
        cursor: None,
        sort_key: ThreadSortKey::UpdatedAt,
        sort_direction: SortDirection::Desc,
        allowed_sources: Vec::new(),
        model_providers: Some(Vec::new()),
        cwd_filters: Some(vec![workspace]),
        archived: false,
        search_term: None,
        use_state_db_only: true,
    };
    let page = store
        .list_threads(list_params)
        .await
        .expect("list threads by cwd");
    let listed_ids: Vec<_> = page.items.iter().map(|item| item.thread_id).collect();
    assert_eq!(listed_ids, vec![thread_id]);
}
#[tokio::test]
async fn update_thread_metadata_keeps_archived_thread_archived_in_sqlite() {
let home = TempDir::new().expect("temp dir");
let config = test_config(home.path());
let uuid = Uuid::from_u128(306);
let uuid = Uuid::from_u128(307);
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
let archived_path = write_archived_session_file(home.path(), "2025-01-03T16-00-00", uuid)
.expect("archived session file");
@@ -931,7 +1422,7 @@ mod tests {
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id,
patch: ThreadMetadataPatch {
name: Some("Archived title".to_string()),
name: Some(Some("Archived title".to_string())),
..Default::default()
},
include_archived: true,
@@ -996,7 +1487,7 @@ mod tests {
.update_thread_metadata(UpdateThreadMetadataParams {
thread_id,
patch: ThreadMetadataPatch {
name: Some("Live archived title".to_string()),
name: Some(Some("Live archived title".to_string())),
..Default::default()
},
include_archived: true,
+9 -2
View File
@@ -33,7 +33,11 @@ pub trait ThreadStore: Any + Send + Sync {
/// Reopens an existing thread for live appends.
async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()>;
/// Appends items to a live thread.
/// Appends canonical rollout items to a live thread.
///
/// This is the raw history API. It does not infer metadata from item contents. Callers that
/// need metadata updates should call [`ThreadStore::update_thread_metadata`] with explicit
/// metadata facts prepared above the store.
async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()>;
/// Materializes the thread if persistence is lazy, then persists all queued items.
@@ -86,7 +90,10 @@ pub trait ThreadStore: Any + Send + Sync {
})
}
/// Applies a mutable metadata patch and returns the updated thread.
/// Applies a literal metadata patch and returns the updated thread.
///
/// Implementations should apply the supplied fields directly. Policy such as deciding whether
/// an append-derived preview should be emitted belongs above the store.
async fn update_thread_metadata(
&self,
params: UpdateThreadMetadataParams,
@@ -0,0 +1,580 @@
use std::time::Duration;
use std::time::Instant;
use chrono::DateTime;
use chrono::NaiveDateTime;
use chrono::Utc;
use codex_git_utils::collect_git_info;
use codex_git_utils::get_git_repo_root;
use codex_protocol::ThreadId;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::GitInfo;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::ThreadMemoryMode;
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
use codex_protocol::protocol::UserMessageEvent;
use crate::CreateThreadParams;
use crate::GitInfoPatch;
use crate::ResumeThreadParams;
use crate::ThreadMetadataPatch;
// NOTE(review): presumably the text used in place of a user message that contains only
// images; the use site is outside this view, so confirm.
const IMAGE_ONLY_USER_MESSAGE_PLACEHOLDER: &str = "[Image]";
// Interval constant with a build-dependent value: 5 seconds in non-test builds.
// NOTE(review): the name suggests it rate-limits `updated_at` touch writes; confirm against
// the code that reads it.
#[cfg(not(test))]
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_secs(5);
// Test builds use 50 milliseconds instead.
#[cfg(test)]
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_millis(50);
/// Live-thread helper that derives metadata updates from canonical rollout items.
///
/// Stores receive raw history plus explicit metadata patches. This helper keeps append-derived
/// metadata observation in the live layer without owning persistence-policy filtering or making
/// `append_items` infer metadata inside a `ThreadStore` implementation.
pub(crate) struct ThreadMetadataSync {
/// Id of the thread this helper belongs to; `for_create` copies it from the create params.
thread_id: ThreadId,
/// `for_create` sets this to `true` when the creation cwd is non-empty.
cwd_seen: bool,
/// Starts `false` in `for_create`.
/// NOTE(review): presumably set once a preview has been derived; confirm.
preview_seen: bool,
/// Starts `false` in `for_create`.
/// NOTE(review): presumably set once the first user message has been observed; confirm.
first_user_message_seen: bool,
/// Starts `false` in `for_create`.
/// NOTE(review): presumably set once a title has been derived; confirm.
title_seen: bool,
/// Metadata patch not yet handed off; `for_create` seeds it with the creation-time patch.
pending_update: Option<ThreadMetadataPatch>,
/// Generation counter associated with `pending_update`; `for_create` starts it at 1.
pending_update_generation: u64,
/// `None` after `for_create`.
/// NOTE(review): presumably the instant of the last persisted `updated_at` touch; confirm.
last_touch_persisted_at: Option<Instant>,
/// `true` after `for_create`.
/// NOTE(review): the name suggests the creation patch is held back until history exists; confirm.
defer_create_update_until_history_exists: bool,
/// `false` after `for_create`.
/// NOTE(review): the name suggests the resume patch is held back until the first append; confirm.
defer_resume_update_until_append: bool,
}
/// Snapshot of the pending metadata patch handed to the caller for persistence.
///
/// Pass it back to [`ThreadMetadataSync::mark_pending_update_applied`] once it has been written.
pub(crate) struct PendingThreadMetadataPatch {
/// The metadata patch to persist.
pub(crate) patch: ThreadMetadataPatch,
/// Value of the sync's generation counter when this snapshot was taken.
generation: u64,
}
impl ThreadMetadataSync {
/// Builds the sync state for a newly created thread.
///
/// The pending patch is seeded with creation-time facts taken from `params` (provider, source,
/// agent identity, cwd, CLI version, memory mode, dynamic tools) plus Git info when
/// `get_git_repo_root` finds a root for the cwd. Both timestamps are set to the same instant.
/// The patch is withheld from [`Self::take_pending_update_for_existing_history`] until the first
/// append is observed.
pub(crate) async fn for_create(params: &CreateThreadParams) -> Self {
    let now = Utc::now();
    let cwd = params.metadata.cwd.clone().unwrap_or_default();

    let observed_git = match get_git_repo_root(cwd.as_path()) {
        Some(_) => collect_git_info(cwd.as_path()).await.map(|info| GitInfo {
            commit_hash: info.commit_hash,
            branch: info.branch,
            repository_url: info.repository_url,
        }),
        None => None,
    };

    // An empty tool list is recorded as "no change" rather than as an empty list.
    let dynamic_tools = if params.dynamic_tools.is_empty() {
        None
    } else {
        Some(params.dynamic_tools.clone())
    };

    let cwd_seen = !cwd.as_os_str().is_empty();
    let source = &params.source;
    let initial_patch = ThreadMetadataPatch {
        model_provider: Some(params.metadata.model_provider.clone()),
        created_at: Some(now),
        updated_at: Some(now),
        source: Some(source.clone()),
        thread_source: Some(params.thread_source),
        agent_nickname: Some(source.get_nickname()),
        agent_role: Some(source.get_agent_role()),
        agent_path: Some(source.get_agent_path().map(Into::into)),
        cwd: Some(cwd),
        cli_version: Some(env!("CARGO_PKG_VERSION").to_string()),
        git_info: observed_git.map(git_info_patch_from_observation),
        memory_mode: Some(params.metadata.memory_mode),
        dynamic_tools,
        ..Default::default()
    };

    Self {
        thread_id: params.thread_id,
        cwd_seen,
        preview_seen: false,
        first_user_message_seen: false,
        title_seen: false,
        pending_update: Some(initial_patch),
        pending_update_generation: 1,
        last_touch_persisted_at: None,
        defer_create_update_until_history_exists: true,
        defer_resume_update_until_append: false,
    }
}
/// Builds the sync state for a resumed thread.
///
/// When `params.history` is present it is replayed to derive a pending patch (without an
/// `updated_at` touch). If that replay leaves anything pending, the patch is withheld from
/// [`Self::take_pending_update_for_existing_history`] until the first append is observed.
pub(crate) fn for_resume(params: &ResumeThreadParams) -> Self {
    let cwd_seen = match params.metadata.cwd.as_ref() {
        Some(cwd) => !cwd.as_os_str().is_empty(),
        None => false,
    };
    let mut state = Self {
        thread_id: params.thread_id,
        cwd_seen,
        preview_seen: false,
        first_user_message_seen: false,
        title_seen: false,
        pending_update: None,
        pending_update_generation: 0,
        last_touch_persisted_at: None,
        defer_create_update_until_history_exists: false,
        defer_resume_update_until_append: false,
    };

    let Some(history) = params.history.as_deref() else {
        return state;
    };
    let derived = state.observe_resume_history(history);
    state.merge_pending_update(derived);
    state.defer_resume_update_until_append = state.pending_update.is_some();
    state
}
/// Returns a snapshot of the pending patch tagged with the current generation.
///
/// Despite the name, this does not remove the pending patch: it stays queued until
/// [`Self::mark_pending_update_applied`] confirms it, so a failed write can be retried.
pub(crate) fn take_pending_update(&self) -> Option<PendingThreadMetadataPatch> {
    let patch = self.pending_update.as_ref()?.clone();
    Some(PendingThreadMetadataPatch {
        patch,
        generation: self.pending_update_generation,
    })
}
/// Like [`Self::take_pending_update`], but yields nothing while either deferral flag is set,
/// i.e. before the first append has been observed on a created or resumed thread.
pub(crate) fn take_pending_update_for_existing_history(
    &self,
) -> Option<PendingThreadMetadataPatch> {
    let deferred =
        self.defer_create_update_until_history_exists || self.defer_resume_update_until_append;
    if deferred {
        None
    } else {
        self.take_pending_update()
    }
}
pub(crate) fn mark_pending_update_applied(&mut self, update: &PendingThreadMetadataPatch) {
if self.pending_update_generation == update.generation {
self.pending_update = None;
}
if update.patch.updated_at.is_some() {
self.last_touch_persisted_at = Some(Instant::now());
}
}
pub(crate) fn observe_appended_items(
&mut self,
items: &[RolloutItem],
) -> Option<PendingThreadMetadataPatch> {
self.defer_create_update_until_history_exists = false;
self.defer_resume_update_until_append = false;
let affects_metadata = items
.iter()
.any(codex_state::rollout_item_affects_thread_metadata);
let update = if affects_metadata {
self.observe_items(items)?
} else {
thread_updated_at_touch()
};
self.merge_pending_update(Some(update));
if !affects_metadata
&& !self
.pending_update
.as_ref()
.is_some_and(update_has_metadata_facts)
&& self.last_touch_persisted_at.is_some_and(|last_touch| {
Instant::now().duration_since(last_touch) < THREAD_UPDATED_AT_TOUCH_INTERVAL
})
{
return None;
}
self.take_pending_update()
}
fn observe_items(&mut self, items: &[RolloutItem]) -> Option<ThreadMetadataPatch> {
self.observe_items_with_update(
items,
ThreadMetadataPatch {
updated_at: Some(Utc::now()),
..Default::default()
},
)
}
/// Observes replayed resume history, starting from an empty patch so that replay alone does
/// not produce an `updated_at` touch.
fn observe_resume_history(&mut self, items: &[RolloutItem]) -> Option<ThreadMetadataPatch> {
    let untouched = ThreadMetadataPatch::default();
    self.observe_items_with_update(items, untouched)
}
/// Folds `items` into `update` in order and returns the resulting patch.
///
/// Returns `None` only for an empty slice; otherwise the (possibly unchanged) patch is returned.
/// The `*_seen` latches on `self` are updated as a side effect so that first-wins fields
/// (cwd from turn context, preview, first user message, title) are emitted at most once.
fn observe_items_with_update(
&mut self,
items: &[RolloutItem],
mut update: ThreadMetadataPatch,
) -> Option<ThreadMetadataPatch> {
if items.is_empty() {
return None;
}
for item in items {
match item {
// Session metadata is only trusted when it belongs to this thread.
RolloutItem::SessionMeta(meta_line) if meta_line.meta.id == self.thread_id => {
// Assigned unconditionally: an unparseable timestamp leaves `created_at` as `None`.
update.created_at = parse_session_timestamp(meta_line.meta.timestamp.as_str());
update.source = Some(meta_line.meta.source.clone());
update.thread_source = Some(meta_line.meta.thread_source);
update.agent_nickname = Some(meta_line.meta.agent_nickname.clone());
update.agent_role = Some(meta_line.meta.agent_role.clone());
update.agent_path = Some(meta_line.meta.agent_path.clone());
// Empty strings/paths are treated as "not provided" and leave the patch field untouched.
if let Some(model_provider) = meta_line.meta.model_provider.clone()
&& !model_provider.is_empty()
{
update.model_provider = Some(model_provider);
}
if !meta_line.meta.cli_version.is_empty() {
update.cli_version = Some(meta_line.meta.cli_version.clone());
}
if !meta_line.meta.cwd.as_os_str().is_empty() {
self.cwd_seen = true;
update.cwd = Some(meta_line.meta.cwd.clone());
}
if let Some(git_info) = meta_line.git.clone() {
update.git_info = Some(git_info_patch_from_observation(git_info));
}
// Unrecognized memory-mode strings are ignored.
if let Some(memory_mode) = meta_line.meta.memory_mode.as_deref()
&& let Some(memory_mode) = parse_memory_mode(memory_mode)
{
update.memory_mode = Some(memory_mode);
}
if let Some(dynamic_tools) = meta_line.meta.dynamic_tools.clone() {
update.dynamic_tools = Some(dynamic_tools);
}
}
RolloutItem::TurnContext(turn_ctx) => {
// Turn-context cwd is only a fallback for threads with no cwd recorded yet.
if !self.cwd_seen && !turn_ctx.cwd.as_os_str().is_empty() {
self.cwd_seen = true;
update.cwd = Some(turn_ctx.cwd.clone());
}
// Latest turn wins for model settings; `effort` is copied as-is, including `None`.
update.model = Some(turn_ctx.model.clone());
update.reasoning_effort = turn_ctx.effort;
update.approval_mode = Some(turn_ctx.approval_policy);
update.sandbox_policy = Some(turn_ctx.sandbox_policy.clone());
}
RolloutItem::EventMsg(EventMsg::UserMessage(user)) => {
// The preview may be the image placeholder; the title below requires actual text.
if let Some(preview) = user_message_preview(user) {
if !self.first_user_message_seen {
self.first_user_message_seen = true;
update.first_user_message = Some(preview.clone());
}
if !self.preview_seen {
self.preview_seen = true;
update.preview = Some(preview);
}
}
if !self.title_seen {
let title = strip_user_message_prefix(user.message.as_str());
if !title.is_empty() {
self.title_seen = true;
update.title = Some(title.to_string());
}
}
}
RolloutItem::EventMsg(EventMsg::TokenCount(token_count)) => {
if let Some(info) = token_count.info.as_ref() {
update.token_usage = Some(info.total_token_usage.clone());
}
}
RolloutItem::EventMsg(EventMsg::ThreadGoalUpdated(event)) => {
// A goal objective can supply the preview, but never replaces one already emitted.
if !self.preview_seen {
let objective = event.goal.objective.trim();
if !objective.is_empty() {
self.preview_seen = true;
update.preview = Some(objective.to_string());
}
}
}
// Session metadata for other threads and all remaining item kinds carry no metadata here.
RolloutItem::SessionMeta(_)
| RolloutItem::EventMsg(_)
| RolloutItem::ResponseItem(_)
| RolloutItem::Compacted(_) => {}
}
}
Some(update)
}
/// Merges `update` into the pending patch and bumps the generation counter.
///
/// A `None` update is a no-op and leaves the generation untouched.
fn merge_pending_update(&mut self, update: Option<ThreadMetadataPatch>) {
    if let Some(incoming) = update {
        let merged = match self.pending_update.take() {
            Some(mut accumulated) => {
                accumulated.merge(incoming);
                accumulated
            }
            None => incoming,
        };
        self.pending_update = Some(merged);
        self.pending_update_generation = self.pending_update_generation.wrapping_add(1);
    }
}
}
/// Maps the persisted memory-mode string to its enum value.
///
/// Only the exact strings `"enabled"` and `"disabled"` are recognized; anything else is `None`.
fn parse_memory_mode(value: &str) -> Option<ThreadMemoryMode> {
    if value == "enabled" {
        Some(ThreadMemoryMode::Enabled)
    } else if value == "disabled" {
        Some(ThreadMemoryMode::Disabled)
    } else {
        None
    }
}
/// Parses a session timestamp into UTC.
///
/// RFC 3339 is tried first; on failure the value is parsed as a naive
/// `%Y-%m-%dT%H-%M-%S` timestamp and interpreted as UTC. Returns `None` if both fail.
fn parse_session_timestamp(value: &str) -> Option<DateTime<Utc>> {
    if let Ok(timestamp) = DateTime::parse_from_rfc3339(value) {
        return Some(timestamp.with_timezone(&Utc));
    }
    NaiveDateTime::parse_from_str(value, "%Y-%m-%dT%H-%M-%S")
        .ok()
        .map(|naive| DateTime::from_naive_utc_and_offset(naive, Utc))
}
fn strip_user_message_prefix(text: &str) -> &str {
match text.find(USER_MESSAGE_BEGIN) {
Some(idx) => text[idx + USER_MESSAGE_BEGIN.len()..].trim(),
None => text.trim(),
}
}
/// Derives preview text for a user message.
///
/// Uses the message text (marker stripped, trimmed) when it is non-empty; otherwise falls back
/// to the image placeholder if the message carries any images; otherwise yields `None`.
fn user_message_preview(user: &UserMessageEvent) -> Option<String> {
    let text = strip_user_message_prefix(user.message.as_str());
    if !text.is_empty() {
        return Some(text.to_string());
    }

    let has_images = user
        .images
        .as_ref()
        .is_some_and(|images| !images.is_empty());
    let has_local_images = !user.local_images.is_empty();
    if has_images || has_local_images {
        Some(IMAGE_ONLY_USER_MESSAGE_PLACEHOLDER.to_string())
    } else {
        None
    }
}
fn thread_updated_at_touch() -> ThreadMetadataPatch {
ThreadMetadataPatch {
updated_at: Some(Utc::now()),
..Default::default()
}
}
/// Reports whether `update` carries anything beyond an `updated_at` touch.
///
/// NOTE(review): `name` is not part of this check even though the patch type has that field;
/// presumably names never flow through append-derived patches. Confirm before relying on it.
fn update_has_metadata_facts(update: &ThreadMetadataPatch) -> bool {
    let populated = [
        update.rollout_path.is_some(),
        update.preview.is_some(),
        update.title.is_some(),
        update.model_provider.is_some(),
        update.model.is_some(),
        update.reasoning_effort.is_some(),
        update.created_at.is_some(),
        update.source.is_some(),
        update.thread_source.is_some(),
        update.agent_nickname.is_some(),
        update.agent_role.is_some(),
        update.agent_path.is_some(),
        update.cwd.is_some(),
        update.cli_version.is_some(),
        update.approval_mode.is_some(),
        update.sandbox_policy.is_some(),
        update.token_usage.is_some(),
        update.first_user_message.is_some(),
        update.git_info.is_some(),
        update.memory_mode.is_some(),
        update.dynamic_tools.is_some(),
    ];
    populated.iter().any(|&is_set| is_set)
}
fn git_info_patch_from_observation(git_info: GitInfo) -> GitInfoPatch {
GitInfoPatch {
sha: git_info.commit_hash.map(|sha| Some(sha.0)),
branch: git_info.branch.map(Some),
origin_url: git_info.repository_url.map(Some),
}
}
// Unit tests for `ThreadMetadataSync`: resume replay, first-wins fields, touch coalescing, and
// deferral of resume-derived metadata.
#[cfg(test)]
mod tests {
use codex_protocol::protocol::CompactedItem;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::ThreadGoal;
use codex_protocol::protocol::ThreadGoalStatus;
use codex_protocol::protocol::ThreadGoalUpdatedEvent;
use codex_protocol::protocol::UserMessageEvent;
use pretty_assertions::assert_eq;
use super::*;
use crate::ThreadEventPersistenceMode;
use crate::ThreadPersistenceMetadata;
// Resume replay derives metadata without an `updated_at` touch, and the pending patch survives
// repeated reads until it is explicitly marked applied.
#[test]
fn resume_history_keeps_derived_metadata_pending_until_applied() {
let thread_id = ThreadId::new();
let mut sync = ThreadMetadataSync::for_resume(&resume_params(
thread_id,
vec![
RolloutItem::SessionMeta(session_meta(thread_id)),
RolloutItem::EventMsg(EventMsg::UserMessage(user_message("hello metadata"))),
],
));
let update = sync.take_pending_update().expect("pending metadata update");
assert_eq!(
update
.patch
.created_at
.expect("created_at should come from session metadata")
.to_rfc3339(),
"2025-01-03T12:00:00+00:00"
);
assert_eq!(update.patch.preview.as_deref(), Some("hello metadata"));
assert_eq!(update.patch.title.as_deref(), Some("hello metadata"));
assert_eq!(
update.patch.first_user_message.as_deref(),
Some("hello metadata")
);
assert_eq!(update.patch.updated_at, None);
assert!(
sync.take_pending_update().is_some(),
"taking the pending update should not drop retry state"
);
sync.mark_pending_update_applied(&update);
assert!(sync.take_pending_update().is_none());
}
// A goal objective seen first claims the preview; the later user message still supplies the
// first user message and the title.
#[test]
fn goal_update_sets_preview_without_overriding_existing_preview() {
let thread_id = ThreadId::new();
let sync = ThreadMetadataSync::for_resume(&resume_params(
thread_id,
vec![
RolloutItem::EventMsg(EventMsg::ThreadGoalUpdated(goal_update(
thread_id,
"ship the refactor",
))),
RolloutItem::EventMsg(EventMsg::UserMessage(user_message("first user text"))),
],
));
let update = sync.take_pending_update().expect("pending metadata update");
assert_eq!(update.patch.preview.as_deref(), Some("ship the refactor"));
assert_eq!(
update.patch.first_user_message.as_deref(),
Some("first user text")
);
assert_eq!(update.patch.title.as_deref(), Some("first user text"));
}
// Once the first-wins fields have been emitted and applied, later user messages only produce
// an `updated_at` touch.
#[test]
fn later_user_messages_do_not_emit_existing_preview_fields() {
let thread_id = ThreadId::new();
let mut sync = ThreadMetadataSync::for_resume(&resume_params(
thread_id,
vec![RolloutItem::EventMsg(EventMsg::UserMessage(user_message(
"first user text",
)))],
));
let pending = sync.take_pending_update().expect("pending resume metadata");
sync.mark_pending_update_applied(&pending);
let update = sync
.observe_appended_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(user_message(
"later user text",
)))])
.expect("updated_at touch");
assert_eq!(update.patch.preview, None);
assert_eq!(update.patch.title, None);
assert_eq!(update.patch.first_user_message, None);
assert!(update.patch.updated_at.is_some());
}
// Touch-only appends inside the coalescing window are withheld from the append path but stay
// queued in the pending patch.
#[test]
fn metadata_irrelevant_items_coalesce_updated_at_touches() {
let thread_id = ThreadId::new();
let mut sync = ThreadMetadataSync::for_resume(&resume_params(thread_id, Vec::new()));
let item = RolloutItem::Compacted(CompactedItem {
message: "compacted".to_string(),
replacement_history: None,
});
let first = sync
.observe_appended_items(std::slice::from_ref(&item))
.expect("first touch should apply immediately");
assert!(first.patch.updated_at.is_some());
sync.mark_pending_update_applied(&first);
assert!(
sync.observe_appended_items(std::slice::from_ref(&item))
.is_none(),
"second touch inside the coalescing window should wait for a barrier"
);
assert!(
sync.take_pending_update().is_some(),
"coalesced touches still flush at the next barrier"
);
}
// Metadata derived purely from resume history is not flushed until a new append arrives.
#[test]
fn resume_history_waits_for_append_before_flushing_metadata() {
let thread_id = ThreadId::new();
let mut sync = ThreadMetadataSync::for_resume(&resume_params(
thread_id,
vec![
RolloutItem::SessionMeta(session_meta(thread_id)),
RolloutItem::EventMsg(EventMsg::UserMessage(user_message("hello metadata"))),
],
));
assert!(
sync.take_pending_update_for_existing_history().is_none(),
"resume-only metadata should not flush without a new append"
);
assert!(
sync.observe_appended_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
user_message("new append"),
))])
.is_some(),
"the first append should flush resume metadata together with append metadata"
);
}
// Builds resume parameters with the given history and no cwd.
fn resume_params(thread_id: ThreadId, history: Vec<RolloutItem>) -> ResumeThreadParams {
ResumeThreadParams {
thread_id,
rollout_path: None,
history: Some(history),
include_archived: false,
metadata: ThreadPersistenceMetadata {
cwd: None,
model_provider: "test-provider".to_string(),
memory_mode: ThreadMemoryMode::Enabled,
},
event_persistence_mode: ThreadEventPersistenceMode::Limited,
}
}
// Builds a text-only user message event.
fn user_message(message: &str) -> UserMessageEvent {
UserMessageEvent {
message: message.to_string(),
images: None,
local_images: Vec::new(),
text_elements: Vec::new(),
}
}
// Builds a session-meta line for `thread_id` with a fixed RFC 3339 timestamp and no Git info.
fn session_meta(thread_id: ThreadId) -> SessionMetaLine {
SessionMetaLine {
meta: SessionMeta {
id: thread_id,
timestamp: "2025-01-03T12:00:00Z".to_string(),
source: SessionSource::Exec,
..Default::default()
},
git: None,
}
}
// Builds a goal-updated event carrying an active goal with the given objective.
fn goal_update(thread_id: ThreadId, objective: &str) -> ThreadGoalUpdatedEvent {
ThreadGoalUpdatedEvent {
thread_id,
turn_id: None,
goal: ThreadGoal {
thread_id,
objective: objective.to_string(),
status: ThreadGoalStatus::Active,
token_budget: None,
tokens_used: 0,
time_used_seconds: 0,
created_at: 0,
updated_at: 0,
},
}
}
}
+358 -11
View File
@@ -15,7 +15,32 @@ use codex_protocol::protocol::ThreadMemoryMode as MemoryMode;
use codex_protocol::protocol::ThreadSource;
use codex_protocol::protocol::TokenUsage;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
/// Serde adapter for `Option<Option<T>>` patch fields.
///
/// Intended for use with `#[serde(default, skip_serializing_if = "Option::is_none")]`, so an
/// outer `None` is normally omitted from the output and restored by `default` on input.
mod optional_option {
    use super::*;

    /// Serializes the inner option directly: `Some(None)` becomes `null`, `Some(Some(v))`
    /// becomes `v`. An outer `None` is also written as `null` if it is not skipped.
    pub fn serialize<T, S>(value: &Option<Option<T>>, serializer: S) -> Result<S::Ok, S::Error>
    where
        T: Serialize,
        S: Serializer,
    {
        if let Some(inner) = value {
            inner.serialize(serializer)
        } else {
            serializer.serialize_none()
        }
    }

    /// Deserializes a present value as `Some(inner)`, so an explicit `null` yields `Some(None)`.
    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
    where
        T: Deserialize<'de>,
        D: Deserializer<'de>,
    {
        let inner = Option::<T>::deserialize(deserializer)?;
        Ok(Some(inner))
    }
}
/// Controls how many event variants should be persisted for future replay.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
@@ -348,32 +373,241 @@ pub struct StoredThread {
}
/// Optional field patch where omission leaves a value unchanged and `Some(None)` clears it.
pub type OptionalStringPatch = Option<Option<String>>;
/// Generic clearable patch field: `None` leaves the value unchanged, `Some(None)` clears it,
/// and `Some(Some(value))` replaces it.
pub type ClearableField<T> = Option<Option<T>>;
/// Patch for thread Git metadata.
///
/// Each field is a clearable patch: `None` leaves the stored value unchanged, `Some(None)`
/// clears it, and `Some(Some(value))` replaces it. When serialized, an unchanged field is
/// omitted and a clear request is written as `null`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct GitInfoPatch {
/// Replacement commit SHA, clear request, or no-op.
pub sha: OptionalStringPatch,
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub sha: ClearableField<String>,
/// Replacement branch name, clear request, or no-op.
pub branch: OptionalStringPatch,
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub branch: ClearableField<String>,
/// Replacement origin URL, clear request, or no-op.
pub origin_url: OptionalStringPatch,
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub origin_url: ClearableField<String>,
}
/// Patch for mutable thread metadata.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
impl GitInfoPatch {
    /// Overlays `next` onto this patch, field by field.
    ///
    /// A field that `next` leaves as `None` keeps its current value here. A field that `next`
    /// sets, whether a replacement or a clear request (`Some(None)`), overwrites the current one.
    pub fn merge(&mut self, next: Self) {
        let Self {
            sha,
            branch,
            origin_url,
        } = next;
        if sha.is_some() {
            self.sha = sha;
        }
        if branch.is_some() {
            self.branch = branch;
        }
        if origin_url.is_some() {
            self.origin_url = origin_url;
        }
    }
}
/// Patch for thread metadata.
///
/// Every field is literal: `None` leaves that field unchanged, while `Some`
/// applies the supplied value. Fields whose value may itself be cleared use an
/// inner `Option`, where `Some(None)` clears the field.
///
/// Clearable fields go through the `optional_option` serde adapter: when serialized, an
/// unchanged field is omitted and a clear request is written as `null`; when deserialized, a
/// missing key becomes `None` and an explicit `null` becomes `Some(None)`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ThreadMetadataPatch {
/// Replacement user-facing thread name.
pub name: Option<String>,
/// Replacement thread memory behavior.
pub memory_mode: Option<MemoryMode>,
/// Optional Git metadata patch.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub name: ClearableField<String>,
/// Known local rollout path for stores that expose one.
pub rollout_path: Option<PathBuf>,
/// Best available preview text for discovery/listing.
pub preview: Option<String>,
/// Best-effort title derived from history.
pub title: Option<String>,
/// Model provider associated with the thread.
pub model_provider: Option<String>,
/// Latest observed model.
pub model: Option<String>,
/// Latest observed reasoning effort.
pub reasoning_effort: Option<ReasoningEffort>,
/// Creation timestamp when known.
pub created_at: Option<DateTime<Utc>>,
/// Last update timestamp for this metadata observation.
pub updated_at: Option<DateTime<Utc>>,
/// Session source.
pub source: Option<SessionSource>,
/// Optional analytics source classification.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub thread_source: ClearableField<ThreadSource>,
/// Optional agent nickname.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub agent_nickname: ClearableField<String>,
/// Optional agent role.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub agent_role: ClearableField<String>,
/// Optional canonical agent path.
#[serde(
default,
skip_serializing_if = "Option::is_none",
with = "optional_option"
)]
pub agent_path: ClearableField<String>,
/// Working directory.
pub cwd: Option<PathBuf>,
/// CLI version that created the thread.
pub cli_version: Option<String>,
/// Approval mode.
pub approval_mode: Option<AskForApproval>,
/// Sandbox policy.
pub sandbox_policy: Option<SandboxPolicy>,
/// Last observed token usage.
pub token_usage: Option<TokenUsage>,
/// First user message observed for this thread.
pub first_user_message: Option<String>,
/// Git metadata patch.
pub git_info: Option<GitInfoPatch>,
/// Thread memory behavior.
pub memory_mode: Option<MemoryMode>,
/// Dynamic tools available to this thread.
pub dynamic_tools: Option<Vec<DynamicToolSpec>>,
}
impl ThreadMetadataPatch {
/// Merges another patch into this one using field-presence semantics.
///
/// Omitted fields in `next` leave the current patch unchanged. Present fields replace the
/// current value, including clear requests like `Some(None)`. Nested patches use the same
/// semantics.
pub fn merge(&mut self, next: Self) {
if next.name.is_some() {
self.name = next.name;
}
if next.rollout_path.is_some() {
self.rollout_path = next.rollout_path;
}
if next.preview.is_some() {
self.preview = next.preview;
}
if next.title.is_some() {
self.title = next.title;
}
if next.model_provider.is_some() {
self.model_provider = next.model_provider;
}
if next.model.is_some() {
self.model = next.model;
}
if next.reasoning_effort.is_some() {
self.reasoning_effort = next.reasoning_effort;
}
if next.created_at.is_some() {
self.created_at = next.created_at;
}
if next.updated_at.is_some() {
self.updated_at = next.updated_at;
}
if next.source.is_some() {
self.source = next.source;
}
if next.thread_source.is_some() {
self.thread_source = next.thread_source;
}
if next.agent_nickname.is_some() {
self.agent_nickname = next.agent_nickname;
}
if next.agent_role.is_some() {
self.agent_role = next.agent_role;
}
if next.agent_path.is_some() {
self.agent_path = next.agent_path;
}
if next.cwd.is_some() {
self.cwd = next.cwd;
}
if next.cli_version.is_some() {
self.cli_version = next.cli_version;
}
if next.approval_mode.is_some() {
self.approval_mode = next.approval_mode;
}
if next.sandbox_policy.is_some() {
self.sandbox_policy = next.sandbox_policy;
}
if next.token_usage.is_some() {
self.token_usage = next.token_usage;
}
if next.first_user_message.is_some() {
self.first_user_message = next.first_user_message;
}
if let Some(git_info) = next.git_info {
self.git_info
.get_or_insert_with(GitInfoPatch::default)
.merge(git_info);
}
if next.memory_mode.is_some() {
self.memory_mode = next.memory_mode;
}
if next.dynamic_tools.is_some() {
self.dynamic_tools = next.dynamic_tools;
}
}
pub fn is_empty(&self) -> bool {
self.name.is_none()
&& self.rollout_path.is_none()
&& self.preview.is_none()
&& self.title.is_none()
&& self.model_provider.is_none()
&& self.model.is_none()
&& self.reasoning_effort.is_none()
&& self.created_at.is_none()
&& self.updated_at.is_none()
&& self.source.is_none()
&& self.thread_source.is_none()
&& self.agent_nickname.is_none()
&& self.agent_role.is_none()
&& self.agent_path.is_none()
&& self.cwd.is_none()
&& self.cli_version.is_none()
&& self.approval_mode.is_none()
&& self.sandbox_policy.is_none()
&& self.token_usage.is_none()
&& self.first_user_message.is_none()
&& self.git_info.is_none()
&& self.memory_mode.is_none()
&& self.dynamic_tools.is_none()
}
}
/// Parameters for patching mutable thread metadata.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct UpdateThreadMetadataParams {
/// Thread id to update.
pub thread_id: ThreadId,
@@ -389,3 +623,116 @@ pub struct ArchiveThreadParams {
/// Thread id to archive or unarchive.
pub thread_id: ThreadId,
}
// Unit tests for the patch types: serde round-trips of clear requests and merge semantics.
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;
use super::*;
// Clear requests (`Some(None)`) serialize as explicit `null` and decode back to `Some(None)`.
#[test]
fn thread_metadata_patch_round_trips_optional_clears() {
let patch = ThreadMetadataPatch {
name: Some(None),
thread_source: Some(None),
agent_nickname: Some(None),
agent_role: Some(None),
agent_path: Some(None),
..Default::default()
};
let value = serde_json::to_value(&patch).expect("serialize patch");
assert_eq!(value["name"], json!(null));
assert_eq!(value["thread_source"], json!(null));
assert_eq!(value["agent_nickname"], json!(null));
assert_eq!(value["agent_role"], json!(null));
assert_eq!(value["agent_path"], json!(null));
let decoded: ThreadMetadataPatch =
serde_json::from_value(value).expect("deserialize patch");
assert_eq!(decoded.name, Some(None));
assert_eq!(decoded.thread_source, Some(None));
assert_eq!(decoded.agent_nickname, Some(None));
assert_eq!(decoded.agent_role, Some(None));
assert_eq!(decoded.agent_path, Some(None));
}
// In the nested Git patch, an unchanged field is omitted, a replacement is written as its
// value, and a clear request is written as `null`; all three survive a round-trip.
#[test]
fn git_info_patch_round_trips_optional_clears() {
let patch = ThreadMetadataPatch {
git_info: Some(GitInfoPatch {
sha: None,
branch: Some(Some("main".to_string())),
origin_url: Some(None),
}),
..Default::default()
};
let value = serde_json::to_value(&patch).expect("serialize patch");
assert_eq!(
value["git_info"],
json!({
"branch": "main",
"origin_url": null,
})
);
let decoded: ThreadMetadataPatch =
serde_json::from_value(value).expect("deserialize patch");
assert_eq!(
decoded.git_info,
Some(GitInfoPatch {
sha: None,
branch: Some(Some("main".to_string())),
origin_url: Some(None),
})
);
}
// An empty JSON object decodes to an empty patch.
#[test]
fn thread_metadata_patch_accepts_missing_fields() {
let decoded: ThreadMetadataPatch =
serde_json::from_value(json!({})).expect("deserialize legacy patch");
assert!(decoded.is_empty());
}
// Merging overwrites only the fields the incoming patch sets, including inside the nested
// Git patch, and a clear request overrides an earlier replacement.
#[test]
fn thread_metadata_patch_merge_uses_presence_semantics() {
let mut current = ThreadMetadataPatch {
name: Some(Some("old name".to_string())),
preview: Some("old preview".to_string()),
git_info: Some(GitInfoPatch {
sha: Some(Some("abc123".to_string())),
branch: Some(Some("main".to_string())),
origin_url: None,
}),
..Default::default()
};
current.merge(ThreadMetadataPatch {
name: Some(None),
preview: None,
title: Some("new title".to_string()),
git_info: Some(GitInfoPatch {
sha: None,
branch: Some(Some("feature".to_string())),
origin_url: Some(None),
}),
..Default::default()
});
assert_eq!(current.name, Some(None));
assert_eq!(current.preview.as_deref(), Some("old preview"));
assert_eq!(current.title.as_deref(), Some("new title"));
assert_eq!(
current.git_info,
Some(GitInfoPatch {
sha: Some(Some("abc123".to_string())),
branch: Some(Some("feature".to_string())),
origin_url: Some(None),
})
);
}
}