mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
Unify thread metadata updates above store (#22236)
- make ThreadStore::update_thread_metadata accept a broad range of metadata patches - keep ThreadStore::append_items as raw canonical history append (no metadata side effects) - in the local store, write these metadata updates to a combination of sqlite and rollout jsonl files for backwards-compat. It special cases which fields need to go into jsonl vs sqlite vs whatever, confining the awkwardness to just this implementation - in remote stores we can simply persist the metadata directly to a database, no special casing required. - move the "implicit metadata updates triggered by appending rollout items" from the RolloutRecorder (which is local-threadstore-specific) to the LiveThread layer above the ThreadStore, inside of a private helper utility called ThreadMetadataSync. LiveThread calls ThreadStore append_items and update_metadata separately. - Add a generic update metadata method to ThreadManager that works on both live threads and "cold" threads - Call that ThreadManager method from app server code, so app server doesn't need to worry about whether the thread is live or not
This commit is contained in:
Generated
+1
@@ -3687,6 +3687,7 @@ dependencies = [
|
||||
"codex-protocol",
|
||||
"codex-rollout",
|
||||
"codex-state",
|
||||
"codex-utils-path",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
||||
@@ -401,7 +401,6 @@ use codex_thread_store::ThreadMetadataPatch as StoreThreadMetadataPatch;
|
||||
use codex_thread_store::ThreadSortKey as StoreThreadSortKey;
|
||||
use codex_thread_store::ThreadStore;
|
||||
use codex_thread_store::ThreadStoreError;
|
||||
use codex_thread_store::UpdateThreadMetadataParams as StoreUpdateThreadMetadataParams;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use codex_utils_pty::DEFAULT_OUTPUT_BYTES_CAP;
|
||||
use std::collections::HashMap;
|
||||
|
||||
@@ -323,7 +323,7 @@ impl ExternalAgentConfigRequestProcessor {
|
||||
.thread
|
||||
.update_thread_metadata(
|
||||
ThreadMetadataPatch {
|
||||
name: Some(name),
|
||||
name: Some(Some(name)),
|
||||
..Default::default()
|
||||
},
|
||||
/*include_archived*/ false,
|
||||
|
||||
@@ -1430,17 +1430,17 @@ impl ThreadRequestProcessor {
|
||||
};
|
||||
|
||||
let _thread_list_state_permit = self.acquire_thread_list_state_permit().await?;
|
||||
self.thread_store
|
||||
.update_thread_metadata(StoreUpdateThreadMetadataParams {
|
||||
self.thread_manager
|
||||
.update_thread_metadata(
|
||||
thread_id,
|
||||
patch: StoreThreadMetadataPatch {
|
||||
name: Some(name.clone()),
|
||||
StoreThreadMetadataPatch {
|
||||
name: Some(Some(name.clone())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
/*include_archived*/ false,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| thread_store_write_error("set thread name", err))?;
|
||||
.map_err(|err| core_thread_write_error("set thread name", err))?;
|
||||
|
||||
Ok((
|
||||
ThreadSetNameResponse {},
|
||||
@@ -1459,33 +1459,17 @@ impl ThreadRequestProcessor {
|
||||
let thread_id = ThreadId::from_string(&thread_id)
|
||||
.map_err(|err| invalid_request(format!("invalid thread id: {err}")))?;
|
||||
|
||||
if let Ok(thread) = self.thread_manager.get_thread(thread_id).await {
|
||||
if thread.config_snapshot().await.ephemeral {
|
||||
return Err(invalid_request(format!(
|
||||
"ephemeral thread does not support memory mode updates: {thread_id}"
|
||||
)));
|
||||
}
|
||||
|
||||
thread
|
||||
.set_thread_memory_mode(mode.to_core())
|
||||
.await
|
||||
.map_err(|err| {
|
||||
internal_error(format!("failed to set thread memory mode: {err}"))
|
||||
})?;
|
||||
return Ok(ThreadMemoryModeSetResponse {});
|
||||
}
|
||||
|
||||
self.thread_store
|
||||
.update_thread_metadata(StoreUpdateThreadMetadataParams {
|
||||
self.thread_manager
|
||||
.update_thread_metadata(
|
||||
thread_id,
|
||||
patch: StoreThreadMetadataPatch {
|
||||
StoreThreadMetadataPatch {
|
||||
memory_mode: Some(mode.to_core()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
/*include_archived*/ false,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| thread_store_write_error("set thread memory mode", err))?;
|
||||
.map_err(|err| core_thread_write_error("set thread memory mode", err))?;
|
||||
|
||||
Ok(ThreadMemoryModeSetResponse {})
|
||||
}
|
||||
@@ -1551,35 +1535,19 @@ impl ThreadRequestProcessor {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let loaded_thread = self.thread_manager.get_thread(thread_uuid).await.ok();
|
||||
let updated_thread = {
|
||||
let _thread_list_state_permit = self.acquire_thread_list_state_permit().await?;
|
||||
if let Some(loaded_thread) = loaded_thread.as_ref() {
|
||||
if loaded_thread.config_snapshot().await.ephemeral {
|
||||
return Err(invalid_request(format!(
|
||||
"ephemeral thread does not support metadata updates: {thread_id}"
|
||||
)));
|
||||
}
|
||||
loaded_thread
|
||||
.update_thread_metadata(patch, /*include_archived*/ true)
|
||||
.await
|
||||
} else {
|
||||
self.thread_store
|
||||
.update_thread_metadata(StoreUpdateThreadMetadataParams {
|
||||
thread_id: thread_uuid,
|
||||
patch,
|
||||
include_archived: true,
|
||||
})
|
||||
.await
|
||||
}
|
||||
.map_err(|err| thread_store_write_error("update thread metadata", err))?
|
||||
self.thread_manager
|
||||
.update_thread_metadata(thread_uuid, patch, /*include_archived*/ true)
|
||||
.await
|
||||
.map_err(|err| core_thread_write_error("update thread metadata", err))?
|
||||
};
|
||||
let (mut thread, _) = thread_from_stored_thread(
|
||||
updated_thread,
|
||||
self.config.model_provider_id.as_str(),
|
||||
&self.config.cwd,
|
||||
);
|
||||
if let Some(loaded_thread) = loaded_thread.as_ref() {
|
||||
if let Ok(loaded_thread) = self.thread_manager.get_thread(thread_uuid).await {
|
||||
thread.session_id = loaded_thread.session_configured().session_id.to_string();
|
||||
}
|
||||
self.attach_thread_name(thread_uuid, &mut thread).await;
|
||||
@@ -3707,15 +3675,13 @@ fn conversation_summary_rollout_path_read_error(
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_store_write_error(operation: &str, err: ThreadStoreError) -> JSONRPCErrorError {
|
||||
fn core_thread_write_error(operation: &str, err: CodexErr) -> JSONRPCErrorError {
|
||||
match err {
|
||||
ThreadStoreError::ThreadNotFound { thread_id } => {
|
||||
CodexErr::ThreadNotFound(thread_id) => {
|
||||
invalid_request(format!("thread not found: {thread_id}"))
|
||||
}
|
||||
ThreadStoreError::InvalidRequest { message } => invalid_request(message),
|
||||
ThreadStoreError::Unsupported { operation } => {
|
||||
unsupported_thread_store_operation(operation)
|
||||
}
|
||||
CodexErr::InvalidRequest(message) => invalid_request(message),
|
||||
CodexErr::UnsupportedOperation(message) => method_not_found(message),
|
||||
err => internal_error(format!("failed to {operation}: {err}")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1304,7 +1304,7 @@ async fn seed_pathless_store_thread(
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some("named pathless thread".to_string()),
|
||||
name: Some(Some("named pathless thread".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: true,
|
||||
|
||||
@@ -226,7 +226,7 @@ async fn thread_unarchive_preserves_pathless_store_metadata() -> Result<()> {
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some("named pathless thread".to_string()),
|
||||
name: Some(Some("named pathless thread".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: true,
|
||||
|
||||
@@ -2846,12 +2846,21 @@ impl Session {
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
token_usage: Option<&TokenUsage>,
|
||||
) {
|
||||
self.record_token_usage_info(turn_context, token_usage)
|
||||
.await;
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn record_token_usage_info(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
token_usage: Option<&TokenUsage>,
|
||||
) {
|
||||
if let Some(token_usage) = token_usage {
|
||||
let mut state = self.state.lock().await;
|
||||
state.update_token_info_from_usage(token_usage, turn_context.model_context_window());
|
||||
}
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn recompute_token_usage(&self, turn_context: &TurnContext) {
|
||||
@@ -2892,11 +2901,15 @@ impl Session {
|
||||
turn_context: &TurnContext,
|
||||
new_rate_limits: RateLimitSnapshot,
|
||||
) {
|
||||
self.record_rate_limits_info(new_rate_limits).await;
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn record_rate_limits_info(&self, new_rate_limits: RateLimitSnapshot) {
|
||||
{
|
||||
let mut state = self.state.lock().await;
|
||||
state.set_rate_limits(new_rate_limits);
|
||||
}
|
||||
self.send_token_count_event(turn_context).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn mcp_dependency_prompted(&self) -> HashSet<String> {
|
||||
@@ -2927,7 +2940,7 @@ impl Session {
|
||||
state.set_server_reasoning_included(included);
|
||||
}
|
||||
|
||||
async fn send_token_count_event(&self, turn_context: &TurnContext) {
|
||||
pub(crate) async fn send_token_count_event(&self, turn_context: &TurnContext) {
|
||||
let (info, rate_limits) = {
|
||||
let state = self.state.lock().await;
|
||||
state.token_info_and_rate_limits()
|
||||
|
||||
@@ -1872,6 +1872,7 @@ async fn try_run_sampling_request(
|
||||
Box<dyn ToolArgumentDiffConsumer>,
|
||||
)> = None;
|
||||
let mut should_emit_turn_diff = false;
|
||||
let mut should_emit_token_count = false;
|
||||
let reasoning_effort = turn_context.effective_reasoning_effort_for_tracing();
|
||||
let plan_mode = turn_context.collaboration_mode.mode == ModeKind::Plan;
|
||||
let mut assistant_message_stream_parsers = AssistantMessageStreamParsers::new(plan_mode);
|
||||
@@ -2098,7 +2099,8 @@ async fn try_run_sampling_request(
|
||||
ResponseEvent::RateLimits(snapshot) => {
|
||||
// Update internal state with latest rate limits, but defer sending until
|
||||
// token usage is available to avoid duplicate TokenCount events.
|
||||
sess.update_rate_limits(&turn_context, snapshot).await;
|
||||
sess.record_rate_limits_info(snapshot).await;
|
||||
should_emit_token_count = true;
|
||||
}
|
||||
ResponseEvent::ModelsEtag(etag) => {
|
||||
// Update internal state with latest models etag
|
||||
@@ -2116,8 +2118,9 @@ async fn try_run_sampling_request(
|
||||
&mut assistant_message_stream_parsers,
|
||||
)
|
||||
.await;
|
||||
sess.update_token_usage_info(&turn_context, token_usage.as_ref())
|
||||
sess.record_token_usage_info(&turn_context, token_usage.as_ref())
|
||||
.await;
|
||||
should_emit_token_count = true;
|
||||
should_emit_turn_diff = true;
|
||||
if let Some(false) = end_turn {
|
||||
needs_follow_up = true;
|
||||
@@ -2245,6 +2248,14 @@ async fn try_run_sampling_request(
|
||||
|
||||
drain_in_flight(&mut in_flight, sess.clone(), turn_context.clone()).await?;
|
||||
|
||||
if should_emit_token_count {
|
||||
// A tool call such as request_user_input can intentionally pause the turn. Emit token
|
||||
// counts only after pending tools resolve so clients do not see progress events while the
|
||||
// turn is waiting on the user. This also needs to happen before returning cancellation so
|
||||
// token usage already recorded from the completed response is still persisted.
|
||||
sess.send_token_count_event(&turn_context).await;
|
||||
}
|
||||
|
||||
if cancellation_token.is_cancelled() {
|
||||
return Err(CodexErr::TurnAborted);
|
||||
}
|
||||
|
||||
@@ -58,8 +58,10 @@ use codex_thread_store::LocalThreadStoreConfig;
|
||||
use codex_thread_store::ReadThreadByRolloutPathParams;
|
||||
use codex_thread_store::ReadThreadParams;
|
||||
use codex_thread_store::StoredThread;
|
||||
use codex_thread_store::ThreadMetadataPatch;
|
||||
use codex_thread_store::ThreadStore;
|
||||
use codex_thread_store::ThreadStoreError;
|
||||
use codex_thread_store::UpdateThreadMetadataParams;
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
@@ -456,6 +458,44 @@ impl ThreadManager {
|
||||
self.state.get_thread(thread_id).await
|
||||
}
|
||||
|
||||
/// Updates metadata for loaded and cold threads through one entrypoint.
|
||||
///
|
||||
/// Loaded threads route through `CodexThread`/`LiveThread`, so metadata changes stay ordered
|
||||
/// with live rollout writes. Cold threads go directly to the store, which owns unloaded JSONL
|
||||
/// compatibility and SQLite metadata updates.
|
||||
pub async fn update_thread_metadata(
|
||||
&self,
|
||||
thread_id: ThreadId,
|
||||
patch: ThreadMetadataPatch,
|
||||
include_archived: bool,
|
||||
) -> CodexResult<StoredThread> {
|
||||
if let Ok(thread) = self.get_thread(thread_id).await {
|
||||
if thread.config_snapshot().await.ephemeral {
|
||||
return Err(CodexErr::InvalidRequest(format!(
|
||||
"ephemeral thread does not support metadata updates: {thread_id}"
|
||||
)));
|
||||
}
|
||||
return thread
|
||||
.update_thread_metadata(patch, include_archived)
|
||||
.await
|
||||
.map_err(|err| thread_store_metadata_update_error(thread_id, err));
|
||||
}
|
||||
self.state
|
||||
.thread_store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch,
|
||||
include_archived,
|
||||
})
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
ThreadStoreError::ThreadNotFound { thread_id } => {
|
||||
CodexErr::ThreadNotFound(thread_id)
|
||||
}
|
||||
err => thread_store_metadata_update_error(thread_id, err),
|
||||
})
|
||||
}
|
||||
|
||||
/// List `thread_id` plus all known descendants in its spawn subtree.
|
||||
pub async fn list_agent_subtree_thread_ids(
|
||||
&self,
|
||||
@@ -1298,6 +1338,19 @@ fn thread_store_rollout_read_error(err: ThreadStoreError) -> CodexErr {
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_store_metadata_update_error(thread_id: ThreadId, err: ThreadStoreError) -> CodexErr {
|
||||
match err {
|
||||
ThreadStoreError::ThreadNotFound { thread_id } => CodexErr::ThreadNotFound(thread_id),
|
||||
ThreadStoreError::InvalidRequest { message } => CodexErr::InvalidRequest(message),
|
||||
ThreadStoreError::Unsupported { operation } => CodexErr::UnsupportedOperation(format!(
|
||||
"thread metadata update is not supported by this store: {operation}"
|
||||
)),
|
||||
err => CodexErr::Fatal(format!(
|
||||
"failed to update thread metadata {thread_id}: {err}"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a fork snapshot cut strictly before the nth user message (0-based).
|
||||
///
|
||||
/// Out-of-range values keep the full committed history at a turn boundary, but
|
||||
|
||||
@@ -2497,38 +2497,6 @@ async fn token_count_includes_rate_limits_snapshot() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let first_token_event =
|
||||
wait_for_event(&codex, |msg| matches!(msg, EventMsg::TokenCount(_))).await;
|
||||
let rate_limit_only = match first_token_event {
|
||||
EventMsg::TokenCount(ev) => ev,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let rate_limit_json = serde_json::to_value(&rate_limit_only).unwrap();
|
||||
pretty_assertions::assert_eq!(
|
||||
rate_limit_json,
|
||||
json!({
|
||||
"info": null,
|
||||
"rate_limits": {
|
||||
"limit_id": "codex",
|
||||
"limit_name": null,
|
||||
"primary": {
|
||||
"used_percent": 12.5,
|
||||
"window_minutes": 10,
|
||||
"resets_at": 1704069000
|
||||
},
|
||||
"secondary": {
|
||||
"used_percent": 40.0,
|
||||
"window_minutes": 60,
|
||||
"resets_at": 1704074400
|
||||
},
|
||||
"credits": null,
|
||||
"plan_type": null,
|
||||
"rate_limit_reached_type": null
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
let token_event = wait_for_event(
|
||||
&codex,
|
||||
|msg| matches!(msg, EventMsg::TokenCount(ev) if ev.info.is_some()),
|
||||
|
||||
@@ -891,7 +891,7 @@ async fn handle_response_item_records_tool_result_for_custom_tool_call() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TokenCount(_))).await;
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
|
||||
|
||||
logs_assert(|lines: &[&str]| {
|
||||
let line = lines
|
||||
|
||||
@@ -17,6 +17,7 @@ use core_test_support::responses;
|
||||
use core_test_support::responses::ResponsesRequest;
|
||||
use core_test_support::responses::ev_assistant_message;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_completed_with_tokens;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::sse;
|
||||
@@ -30,6 +31,8 @@ use core_test_support::wait_for_event_match;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn call_output(req: &ResponsesRequest, call_id: &str) -> String {
|
||||
let raw = req.function_call_output(call_id);
|
||||
@@ -118,6 +121,7 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
|
||||
let first_response = sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
ev_function_call(call_id, "request_user_input", &request_args),
|
||||
ev_rate_limits(),
|
||||
ev_completed("resp-1"),
|
||||
]);
|
||||
responses::mount_sse_once(&server, first_response).await;
|
||||
@@ -169,6 +173,22 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
|
||||
assert_eq!(request.call_id, call_id);
|
||||
assert_eq!(request.questions.len(), 1);
|
||||
assert_eq!(request.questions[0].is_other, true);
|
||||
assert!(
|
||||
timeout(Duration::from_millis(200), async {
|
||||
loop {
|
||||
let event = match codex.next_event().await {
|
||||
Ok(event) => event,
|
||||
Err(err) => panic!("event stream should stay open: {err}"),
|
||||
};
|
||||
if matches!(event.msg, EventMsg::TokenCount(_)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
.is_err(),
|
||||
"TokenCount should wait until request_user_input resolves"
|
||||
);
|
||||
|
||||
let mut answers = HashMap::new();
|
||||
answers.insert(
|
||||
@@ -185,6 +205,7 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
|
||||
})
|
||||
.await?;
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TokenCount(_))).await;
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnComplete(_))).await;
|
||||
|
||||
let req = second_mock.single_request();
|
||||
@@ -202,6 +223,118 @@ async fn request_user_input_round_trip_for_mode(mode: ModeKind) -> anyhow::Resul
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ev_rate_limits() -> Value {
|
||||
json!({
|
||||
"type": "codex.rate_limits",
|
||||
"plan_type": "plus",
|
||||
"rate_limits": {
|
||||
"allowed": true,
|
||||
"limit_reached": false,
|
||||
"primary": {
|
||||
"used_percent": 42,
|
||||
"window_minutes": 60,
|
||||
"reset_at": 1700000000
|
||||
},
|
||||
"secondary": null
|
||||
},
|
||||
"code_review_rate_limits": null,
|
||||
"credits": null,
|
||||
"promo": null
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn request_user_input_interrupt_emits_deferred_token_count() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let TestCodex {
|
||||
codex,
|
||||
cwd,
|
||||
session_configured,
|
||||
..
|
||||
} = test_codex().build(&server).await?;
|
||||
|
||||
let call_id = "user-input-interrupt";
|
||||
let request_args = json!({
|
||||
"questions": [{
|
||||
"id": "confirm_path",
|
||||
"header": "Confirm",
|
||||
"question": "Proceed with the plan?",
|
||||
"options": [{
|
||||
"label": "Yes (Recommended)",
|
||||
"description": "Continue the current plan."
|
||||
}, {
|
||||
"label": "No",
|
||||
"description": "Stop and revisit the approach."
|
||||
}]
|
||||
}]
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let response = sse(vec![
|
||||
ev_response_created("resp-interrupt"),
|
||||
ev_function_call(call_id, "request_user_input", &request_args),
|
||||
ev_completed_with_tokens("resp-interrupt", /*total_tokens*/ 77),
|
||||
]);
|
||||
responses::mount_sse_once(&server, response).await;
|
||||
|
||||
let (sandbox_policy, permission_profile) =
|
||||
turn_permission_fields(PermissionProfile::Disabled, cwd.path());
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
environments: None,
|
||||
items: vec![UserInput::Text {
|
||||
text: "please confirm".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
approvals_reviewer: None,
|
||||
sandbox_policy,
|
||||
permission_profile,
|
||||
model: session_configured.model.clone(),
|
||||
effort: None,
|
||||
summary: None,
|
||||
service_tier: None,
|
||||
collaboration_mode: Some(CollaborationMode {
|
||||
mode: ModeKind::Plan,
|
||||
settings: Settings {
|
||||
model: session_configured.model,
|
||||
reasoning_effort: None,
|
||||
developer_instructions: None,
|
||||
},
|
||||
}),
|
||||
personality: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let request = wait_for_event_match(&codex, |event| match event {
|
||||
EventMsg::RequestUserInput(request) => Some(request.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
codex.submit(Op::Interrupt).await?;
|
||||
|
||||
let token_count = wait_for_event_match(&codex, |event| match event {
|
||||
EventMsg::TokenCount(token_count) => Some(token_count.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
token_count
|
||||
.info
|
||||
.map(|info| info.total_token_usage.total_tokens),
|
||||
Some(77)
|
||||
);
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TurnAborted(_))).await;
|
||||
|
||||
assert_eq!(request.call_id, call_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn assert_request_user_input_rejected<F>(mode_name: &str, build_mode: F) -> anyhow::Result<()>
|
||||
where
|
||||
F: FnOnce(String) -> CollaborationMode,
|
||||
|
||||
@@ -109,7 +109,6 @@ async fn resume_includes_initial_messages_from_rollout_events() -> Result<()> {
|
||||
[
|
||||
EventMsg::TurnStarted(_),
|
||||
EventMsg::UserMessage(_),
|
||||
EventMsg::TokenCount(_),
|
||||
EventMsg::AgentMessage(_),
|
||||
EventMsg::TokenCount(_),
|
||||
EventMsg::TurnComplete(_),
|
||||
@@ -126,7 +125,6 @@ async fn resume_includes_initial_messages_from_rollout_events() -> Result<()> {
|
||||
[
|
||||
EventMsg::TurnStarted(started),
|
||||
EventMsg::UserMessage(first_user),
|
||||
EventMsg::TokenCount(_),
|
||||
EventMsg::AgentMessage(assistant_message),
|
||||
EventMsg::TokenCount(_),
|
||||
EventMsg::TurnComplete(completed),
|
||||
@@ -196,7 +194,6 @@ async fn resume_includes_initial_messages_from_reasoning_events() -> Result<()>
|
||||
[
|
||||
EventMsg::TurnStarted(_),
|
||||
EventMsg::UserMessage(_),
|
||||
EventMsg::TokenCount(_),
|
||||
EventMsg::AgentReasoning(_),
|
||||
EventMsg::AgentReasoningRawContent(_),
|
||||
EventMsg::AgentMessage(_),
|
||||
@@ -215,7 +212,6 @@ async fn resume_includes_initial_messages_from_reasoning_events() -> Result<()>
|
||||
[
|
||||
EventMsg::TurnStarted(started),
|
||||
EventMsg::UserMessage(first_user),
|
||||
EventMsg::TokenCount(_),
|
||||
EventMsg::AgentReasoning(reasoning),
|
||||
EventMsg::AgentReasoningRawContent(raw),
|
||||
EventMsg::AgentMessage(assistant_message),
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use chrono::Utc;
|
||||
use codex_core::EventPersistenceMode;
|
||||
use codex_core::RolloutRecorder;
|
||||
use codex_core::RolloutRecorderParams;
|
||||
use codex_core::config::ConfigBuilder;
|
||||
@@ -189,10 +188,7 @@ async fn find_locates_rollout_file_written_by_recorder() -> std::io::Result<()>
|
||||
/*thread_source*/ None,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
/*state_db_ctx*/ None,
|
||||
/*state_builder*/ None,
|
||||
)
|
||||
.await?;
|
||||
recorder.persist().await?;
|
||||
|
||||
@@ -55,6 +55,7 @@ pub use list::rollout_date_parts;
|
||||
pub use metadata::builder_from_items;
|
||||
pub use policy::EventPersistenceMode;
|
||||
pub use policy::is_persisted_rollout_item;
|
||||
pub use policy::persisted_rollout_items;
|
||||
pub use policy::should_persist_response_item_for_memories;
|
||||
pub use recorder::RolloutRecorder;
|
||||
pub use recorder::RolloutRecorderParams;
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::RolloutItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_utils_string::truncate_middle_chars;
|
||||
|
||||
const PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES: usize = 10_000;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub enum EventPersistenceMode {
|
||||
@@ -22,6 +25,43 @@ pub fn is_persisted_rollout_item(item: &RolloutItem, mode: EventPersistenceMode)
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the canonical rollout items that should be persisted for a live append.
|
||||
pub fn persisted_rollout_items(
|
||||
items: &[RolloutItem],
|
||||
mode: EventPersistenceMode,
|
||||
) -> Vec<RolloutItem> {
|
||||
let mut persisted = Vec::new();
|
||||
for item in items {
|
||||
if is_persisted_rollout_item(item, mode) {
|
||||
persisted.push(sanitize_rollout_item_for_persistence(item.clone(), mode));
|
||||
}
|
||||
}
|
||||
persisted
|
||||
}
|
||||
|
||||
fn sanitize_rollout_item_for_persistence(
|
||||
item: RolloutItem,
|
||||
mode: EventPersistenceMode,
|
||||
) -> RolloutItem {
|
||||
if mode != EventPersistenceMode::Extended {
|
||||
return item;
|
||||
}
|
||||
|
||||
match item {
|
||||
RolloutItem::EventMsg(EventMsg::ExecCommandEnd(mut event)) => {
|
||||
event.aggregated_output = truncate_middle_chars(
|
||||
&event.aggregated_output,
|
||||
PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES,
|
||||
);
|
||||
event.stdout.clear();
|
||||
event.stderr.clear();
|
||||
event.formatted_output.clear();
|
||||
RolloutItem::EventMsg(EventMsg::ExecCommandEnd(event))
|
||||
}
|
||||
_ => item,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a `ResponseItem` should be persisted in rollout files.
|
||||
#[inline]
|
||||
pub fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
|
||||
@@ -8,16 +8,11 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use chrono::DateTime;
|
||||
use chrono::SecondsFormat;
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::dynamic_tools::DynamicToolSpec;
|
||||
use codex_protocol::models::BaseInstructions;
|
||||
use codex_utils_string::truncate_middle_chars;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
@@ -47,15 +42,13 @@ use super::list::get_threads_in_root;
|
||||
use super::list::parse_cursor;
|
||||
use super::list::parse_timestamp_uuid_from_filename;
|
||||
use super::metadata;
|
||||
use super::policy::EventPersistenceMode;
|
||||
use super::policy::is_persisted_rollout_item;
|
||||
use super::session_index::find_thread_names_by_ids;
|
||||
use crate::config::RolloutConfigView;
|
||||
use crate::default_client::originator;
|
||||
use crate::state_db;
|
||||
use crate::state_db::StateDbHandle;
|
||||
use codex_git_utils::collect_git_info;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_git_utils::get_git_repo_root;
|
||||
use codex_protocol::protocol::GitInfo as ProtocolGitInfo;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::protocol::ResumedHistory;
|
||||
@@ -66,11 +59,9 @@ use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::ThreadSource;
|
||||
use codex_state::StateRuntime;
|
||||
use codex_state::ThreadMetadataBuilder;
|
||||
use codex_utils_path as path_utils;
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
/// every update.
|
||||
/// Writes canonical session rollout items to JSONL.
|
||||
///
|
||||
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
|
||||
///
|
||||
@@ -83,7 +74,6 @@ pub struct RolloutRecorder {
|
||||
tx: Sender<RolloutCmd>,
|
||||
writer_task: Arc<RolloutWriterTask>,
|
||||
pub(crate) rollout_path: PathBuf,
|
||||
event_persistence_mode: EventPersistenceMode,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -95,11 +85,9 @@ pub enum RolloutRecorderParams {
|
||||
thread_source: Option<ThreadSource>,
|
||||
base_instructions: BaseInstructions,
|
||||
dynamic_tools: Vec<DynamicToolSpec>,
|
||||
event_persistence_mode: EventPersistenceMode,
|
||||
},
|
||||
Resume {
|
||||
path: PathBuf,
|
||||
event_persistence_mode: EventPersistenceMode,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -172,7 +160,6 @@ impl RolloutRecorderParams {
|
||||
thread_source: Option<ThreadSource>,
|
||||
base_instructions: BaseInstructions,
|
||||
dynamic_tools: Vec<DynamicToolSpec>,
|
||||
event_persistence_mode: EventPersistenceMode,
|
||||
) -> Self {
|
||||
Self::Create {
|
||||
conversation_id,
|
||||
@@ -181,42 +168,11 @@ impl RolloutRecorderParams {
|
||||
thread_source,
|
||||
base_instructions,
|
||||
dynamic_tools,
|
||||
event_persistence_mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resume(path: PathBuf, event_persistence_mode: EventPersistenceMode) -> Self {
|
||||
Self::Resume {
|
||||
path,
|
||||
event_persistence_mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES: usize = 10_000;
|
||||
|
||||
fn sanitize_rollout_item_for_persistence(
|
||||
item: RolloutItem,
|
||||
mode: EventPersistenceMode,
|
||||
) -> RolloutItem {
|
||||
if mode != EventPersistenceMode::Extended {
|
||||
return item;
|
||||
}
|
||||
|
||||
match item {
|
||||
RolloutItem::EventMsg(EventMsg::ExecCommandEnd(mut event)) => {
|
||||
// Persist only a bounded aggregated summary of command output.
|
||||
event.aggregated_output = truncate_middle_chars(
|
||||
&event.aggregated_output,
|
||||
PERSISTED_EXEC_AGGREGATED_OUTPUT_MAX_BYTES,
|
||||
);
|
||||
// Drop unnecessary fields from rollout storage since aggregated_output is all we need.
|
||||
event.stdout.clear();
|
||||
event.stderr.clear();
|
||||
event.formatted_output.clear();
|
||||
RolloutItem::EventMsg(EventMsg::ExecCommandEnd(event))
|
||||
}
|
||||
_ => item,
|
||||
pub fn resume(path: PathBuf) -> Self {
|
||||
Self::Resume { path }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,80 +647,65 @@ impl RolloutRecorder {
|
||||
pub async fn new(
|
||||
config: &impl RolloutConfigView,
|
||||
params: RolloutRecorderParams,
|
||||
state_db_ctx: Option<StateDbHandle>,
|
||||
state_builder: Option<ThreadMetadataBuilder>,
|
||||
) -> std::io::Result<Self> {
|
||||
let (file, deferred_log_file_info, rollout_path, meta, event_persistence_mode) =
|
||||
match params {
|
||||
RolloutRecorderParams::Create {
|
||||
conversation_id,
|
||||
let (file, deferred_log_file_info, rollout_path, meta) = match params {
|
||||
RolloutRecorderParams::Create {
|
||||
conversation_id,
|
||||
forked_from_id,
|
||||
source,
|
||||
thread_source,
|
||||
base_instructions,
|
||||
dynamic_tools,
|
||||
} => {
|
||||
let log_file_info = precompute_log_file_info(config, conversation_id)?;
|
||||
let path = log_file_info.path.clone();
|
||||
let session_id = log_file_info.conversation_id;
|
||||
let started_at = log_file_info.timestamp;
|
||||
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = started_at
|
||||
.to_offset(time::UtcOffset::UTC)
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let session_meta = SessionMeta {
|
||||
id: session_id,
|
||||
forked_from_id,
|
||||
timestamp,
|
||||
cwd: config.cwd().to_path_buf(),
|
||||
originator: originator().value,
|
||||
cli_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
agent_nickname: source.get_nickname(),
|
||||
agent_role: source.get_agent_role(),
|
||||
agent_path: source.get_agent_path().map(Into::into),
|
||||
source,
|
||||
thread_source,
|
||||
base_instructions,
|
||||
dynamic_tools,
|
||||
event_persistence_mode,
|
||||
} => {
|
||||
let log_file_info = precompute_log_file_info(config, conversation_id)?;
|
||||
let path = log_file_info.path.clone();
|
||||
let session_id = log_file_info.conversation_id;
|
||||
let started_at = log_file_info.timestamp;
|
||||
model_provider: Some(config.model_provider_id().to_string()),
|
||||
base_instructions: Some(base_instructions),
|
||||
dynamic_tools: if dynamic_tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(dynamic_tools)
|
||||
},
|
||||
memory_mode: (!config.generate_memories()).then_some("disabled".to_string()),
|
||||
};
|
||||
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
let timestamp = started_at
|
||||
.to_offset(time::UtcOffset::UTC)
|
||||
.format(timestamp_format)
|
||||
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
|
||||
|
||||
let session_meta = SessionMeta {
|
||||
id: session_id,
|
||||
forked_from_id,
|
||||
timestamp,
|
||||
cwd: config.cwd().to_path_buf(),
|
||||
originator: originator().value,
|
||||
cli_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
agent_nickname: source.get_nickname(),
|
||||
agent_role: source.get_agent_role(),
|
||||
agent_path: source.get_agent_path().map(Into::into),
|
||||
source,
|
||||
thread_source,
|
||||
model_provider: Some(config.model_provider_id().to_string()),
|
||||
base_instructions: Some(base_instructions),
|
||||
dynamic_tools: if dynamic_tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(dynamic_tools)
|
||||
},
|
||||
memory_mode: (!config.generate_memories())
|
||||
.then_some("disabled".to_string()),
|
||||
};
|
||||
|
||||
(
|
||||
None,
|
||||
Some(log_file_info),
|
||||
path,
|
||||
Some(session_meta),
|
||||
event_persistence_mode,
|
||||
)
|
||||
}
|
||||
RolloutRecorderParams::Resume {
|
||||
path,
|
||||
event_persistence_mode,
|
||||
} => (
|
||||
Some(
|
||||
tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.await?,
|
||||
),
|
||||
None,
|
||||
path,
|
||||
None,
|
||||
event_persistence_mode,
|
||||
(None, Some(log_file_info), path, Some(session_meta))
|
||||
}
|
||||
RolloutRecorderParams::Resume { path } => (
|
||||
Some(
|
||||
tokio::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.await?,
|
||||
),
|
||||
};
|
||||
None,
|
||||
path,
|
||||
None,
|
||||
),
|
||||
};
|
||||
|
||||
// Clone the cwd for the spawned task to collect git info asynchronously
|
||||
let cwd = config.cwd().to_path_buf();
|
||||
@@ -779,9 +720,6 @@ impl RolloutRecorder {
|
||||
let writer_task = Arc::new(RolloutWriterTask::new());
|
||||
let writer_task_for_spawn = Arc::clone(&writer_task);
|
||||
let rollout_path_for_spawn = rollout_path.clone();
|
||||
let default_provider = config.model_provider_id().to_string();
|
||||
let generate_memories = config.generate_memories();
|
||||
let state_db_ctx_for_spawn = state_db_ctx.clone();
|
||||
let handle = tokio::task::spawn(async move {
|
||||
let result = rollout_writer(
|
||||
file,
|
||||
@@ -790,10 +728,6 @@ impl RolloutRecorder {
|
||||
meta,
|
||||
cwd,
|
||||
rollout_path_for_spawn.clone(),
|
||||
state_db_ctx_for_spawn,
|
||||
state_builder,
|
||||
default_provider,
|
||||
generate_memories,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = result {
|
||||
@@ -812,7 +746,6 @@ impl RolloutRecorder {
|
||||
tx,
|
||||
writer_task,
|
||||
rollout_path,
|
||||
event_persistence_mode,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -820,24 +753,12 @@ impl RolloutRecorder {
|
||||
self.rollout_path.as_path()
|
||||
}
|
||||
|
||||
pub async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
// "fully qualified MCP tool calls," so we could consider
|
||||
// reformatting them in that case.
|
||||
if is_persisted_rollout_item(item, self.event_persistence_mode) {
|
||||
filtered.push(sanitize_rollout_item_for_persistence(
|
||||
item.clone(),
|
||||
self.event_persistence_mode,
|
||||
));
|
||||
}
|
||||
}
|
||||
if filtered.is_empty() {
|
||||
pub async fn record_canonical_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
|
||||
if items.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.send(RolloutCmd::AddItems(items.to_vec()))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
self.writer_task.terminal_failure().unwrap_or_else(|| {
|
||||
@@ -1459,48 +1380,17 @@ struct RolloutWriterState {
|
||||
meta: Option<SessionMeta>,
|
||||
cwd: PathBuf,
|
||||
rollout_path: PathBuf,
|
||||
state_db_ctx: Option<StateDbHandle>,
|
||||
state_builder: Option<ThreadMetadataBuilder>,
|
||||
default_provider: String,
|
||||
generate_memories: bool,
|
||||
thread_updated_at_touch: ThreadUpdatedAtTouch,
|
||||
last_logged_error: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_secs(5);
|
||||
#[cfg(test)]
|
||||
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_millis(50);
|
||||
|
||||
#[derive(Default)]
|
||||
struct ThreadUpdatedAtTouch {
|
||||
last_persisted_at: Option<Instant>,
|
||||
pending_touch: Option<(ThreadId, DateTime<Utc>)>,
|
||||
}
|
||||
|
||||
impl ThreadUpdatedAtTouch {
|
||||
fn mark_persisted(&mut self, now: Instant) {
|
||||
self.last_persisted_at = Some(now);
|
||||
self.pending_touch = None;
|
||||
}
|
||||
}
|
||||
|
||||
impl RolloutWriterState {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
file: Option<tokio::fs::File>,
|
||||
deferred_log_file_info: Option<LogFileInfo>,
|
||||
meta: Option<SessionMeta>,
|
||||
cwd: PathBuf,
|
||||
rollout_path: PathBuf,
|
||||
state_db_ctx: Option<StateDbHandle>,
|
||||
mut state_builder: Option<ThreadMetadataBuilder>,
|
||||
default_provider: String,
|
||||
generate_memories: bool,
|
||||
) -> Self {
|
||||
if let Some(builder) = state_builder.as_mut() {
|
||||
builder.rollout_path = rollout_path.clone();
|
||||
}
|
||||
Self {
|
||||
writer: file.map(|file| JsonlWriter { file }),
|
||||
deferred_log_file_info,
|
||||
@@ -1508,11 +1398,6 @@ impl RolloutWriterState {
|
||||
meta,
|
||||
cwd,
|
||||
rollout_path,
|
||||
state_db_ctx,
|
||||
state_builder,
|
||||
default_provider,
|
||||
generate_memories,
|
||||
thread_updated_at_touch: ThreadUpdatedAtTouch::default(),
|
||||
last_logged_error: None,
|
||||
}
|
||||
}
|
||||
@@ -1545,19 +1430,7 @@ impl RolloutWriterState {
|
||||
if self.is_deferred() && self.pending_items.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.write_pending_with_recovery("shutdown").await?;
|
||||
if let Some((thread_id, updated_at)) = self.thread_updated_at_touch.pending_touch.take()
|
||||
&& state_db::touch_thread_updated_at(
|
||||
self.state_db_ctx.as_deref(),
|
||||
Some(thread_id),
|
||||
updated_at,
|
||||
"rollout_writer_shutdown",
|
||||
)
|
||||
.await
|
||||
{
|
||||
self.thread_updated_at_touch.mark_persisted(Instant::now());
|
||||
}
|
||||
Ok(())
|
||||
self.write_pending_with_recovery("shutdown").await
|
||||
}
|
||||
|
||||
async fn write_pending_with_recovery(&mut self, operation: &str) -> std::io::Result<()> {
|
||||
@@ -1625,18 +1498,7 @@ impl RolloutWriterState {
|
||||
let Some(session_meta) = self.meta.as_ref().cloned() else {
|
||||
return Ok(());
|
||||
};
|
||||
write_session_meta(
|
||||
self.writer.as_mut(),
|
||||
session_meta,
|
||||
&self.cwd,
|
||||
&self.rollout_path,
|
||||
self.state_db_ctx.as_deref(),
|
||||
&mut self.state_builder,
|
||||
self.default_provider.as_str(),
|
||||
self.generate_memories,
|
||||
&mut self.thread_updated_at_touch,
|
||||
)
|
||||
.await?;
|
||||
write_session_meta(self.writer.as_mut(), session_meta, &self.cwd).await?;
|
||||
self.meta = None;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1669,25 +1531,13 @@ impl RolloutWriterState {
|
||||
}
|
||||
|
||||
if written_count > 0 {
|
||||
let written_items: Vec<RolloutItem> =
|
||||
self.pending_items.drain(..written_count).collect();
|
||||
sync_thread_state_after_write(
|
||||
self.state_db_ctx.as_deref(),
|
||||
&self.rollout_path,
|
||||
self.state_builder.as_ref(),
|
||||
written_items.as_slice(),
|
||||
self.default_provider.as_str(),
|
||||
/*new_thread_memory_mode*/ None,
|
||||
&mut self.thread_updated_at_touch,
|
||||
)
|
||||
.await;
|
||||
self.pending_items.drain(..written_count);
|
||||
}
|
||||
|
||||
write_result
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn rollout_writer(
|
||||
file: Option<tokio::fs::File>,
|
||||
deferred_log_file_info: Option<LogFileInfo>,
|
||||
@@ -1695,22 +1545,8 @@ async fn rollout_writer(
|
||||
meta: Option<SessionMeta>,
|
||||
cwd: PathBuf,
|
||||
rollout_path: PathBuf,
|
||||
state_db_ctx: Option<StateDbHandle>,
|
||||
state_builder: Option<ThreadMetadataBuilder>,
|
||||
default_provider: String,
|
||||
generate_memories: bool,
|
||||
) -> std::io::Result<()> {
|
||||
let mut state = RolloutWriterState::new(
|
||||
file,
|
||||
deferred_log_file_info,
|
||||
meta,
|
||||
cwd,
|
||||
rollout_path,
|
||||
state_db_ctx,
|
||||
state_builder,
|
||||
default_provider,
|
||||
generate_memories,
|
||||
);
|
||||
let mut state = RolloutWriterState::new(file, deferred_log_file_info, meta, cwd, rollout_path);
|
||||
|
||||
// Process rollout commands
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
@@ -1740,116 +1576,36 @@ async fn rollout_writer(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn write_session_meta(
|
||||
mut writer: Option<&mut JsonlWriter>,
|
||||
session_meta: SessionMeta,
|
||||
cwd: &Path,
|
||||
rollout_path: &Path,
|
||||
state_db_ctx: Option<&StateRuntime>,
|
||||
state_builder: &mut Option<ThreadMetadataBuilder>,
|
||||
default_provider: &str,
|
||||
generate_memories: bool,
|
||||
thread_updated_at_touch: &mut ThreadUpdatedAtTouch,
|
||||
) -> std::io::Result<()> {
|
||||
let git_info = collect_git_info(cwd).await.map(|info| ProtocolGitInfo {
|
||||
commit_hash: info.commit_hash,
|
||||
branch: info.branch,
|
||||
repository_url: info.repository_url,
|
||||
});
|
||||
let git_info = if get_git_repo_root(cwd).is_some() {
|
||||
collect_git_info(cwd).await.map(|info| ProtocolGitInfo {
|
||||
commit_hash: info.commit_hash,
|
||||
branch: info.branch,
|
||||
repository_url: info.repository_url,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let session_meta_line = SessionMetaLine {
|
||||
meta: session_meta,
|
||||
git: git_info,
|
||||
};
|
||||
if state_db_ctx.is_some() {
|
||||
*state_builder = metadata::builder_from_session_meta(&session_meta_line, rollout_path);
|
||||
}
|
||||
|
||||
let rollout_item = RolloutItem::SessionMeta(session_meta_line);
|
||||
if let Some(writer) = writer.as_mut() {
|
||||
writer.write_rollout_item(&rollout_item).await?;
|
||||
}
|
||||
sync_thread_state_after_write(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
state_builder.as_ref(),
|
||||
std::slice::from_ref(&rollout_item),
|
||||
default_provider,
|
||||
(!generate_memories).then_some("disabled"),
|
||||
thread_updated_at_touch,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sync_thread_state_after_write(
|
||||
state_db_ctx: Option<&StateRuntime>,
|
||||
rollout_path: &Path,
|
||||
state_builder: Option<&ThreadMetadataBuilder>,
|
||||
items: &[RolloutItem],
|
||||
default_provider: &str,
|
||||
new_thread_memory_mode: Option<&str>,
|
||||
thread_updated_at_touch: &mut ThreadUpdatedAtTouch,
|
||||
) {
|
||||
let updated_at = Utc::now();
|
||||
let now = Instant::now();
|
||||
if new_thread_memory_mode.is_some()
|
||||
|| items
|
||||
.iter()
|
||||
.any(codex_state::rollout_item_affects_thread_metadata)
|
||||
{
|
||||
state_db::apply_rollout_items(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
default_provider,
|
||||
state_builder,
|
||||
items,
|
||||
"rollout_writer",
|
||||
new_thread_memory_mode,
|
||||
Some(updated_at),
|
||||
)
|
||||
.await;
|
||||
thread_updated_at_touch.mark_persisted(now);
|
||||
return;
|
||||
}
|
||||
|
||||
let thread_id = state_builder
|
||||
.map(|builder| builder.id)
|
||||
.or_else(|| metadata::builder_from_items(items, rollout_path).map(|builder| builder.id));
|
||||
if thread_updated_at_touch
|
||||
.last_persisted_at
|
||||
.is_some_and(|last_persisted_at| {
|
||||
now.duration_since(last_persisted_at) < THREAD_UPDATED_AT_TOUCH_INTERVAL
|
||||
})
|
||||
{
|
||||
thread_updated_at_touch.pending_touch = thread_id.map(|thread_id| (thread_id, updated_at));
|
||||
return;
|
||||
}
|
||||
|
||||
if state_db::touch_thread_updated_at(state_db_ctx, thread_id, updated_at, "rollout_writer")
|
||||
.await
|
||||
{
|
||||
thread_updated_at_touch.mark_persisted(now);
|
||||
return;
|
||||
}
|
||||
state_db::apply_rollout_items(
|
||||
state_db_ctx,
|
||||
rollout_path,
|
||||
default_provider,
|
||||
state_builder,
|
||||
items,
|
||||
"rollout_writer",
|
||||
new_thread_memory_mode,
|
||||
Some(updated_at),
|
||||
)
|
||||
.await;
|
||||
thread_updated_at_touch.mark_persisted(now);
|
||||
}
|
||||
|
||||
/// Append one already-filtered rollout item to an existing rollout JSONL file.
|
||||
///
|
||||
/// This is for metadata updates to unloaded threads. Live sessions should use
|
||||
/// `RolloutRecorder::record_items` so rollout and SQLite updates remain ordered
|
||||
/// `RolloutRecorder::record_canonical_items` so rollout writes remain ordered
|
||||
/// with the rest of the session stream.
|
||||
pub async fn append_rollout_item_to_path(
|
||||
rollout_path: &Path,
|
||||
|
||||
@@ -372,10 +372,7 @@ async fn recorder_materializes_on_flush_with_pending_items() -> std::io::Result<
|
||||
/*thread_source*/ None,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
/*state_db_ctx*/ None,
|
||||
/*state_builder*/ None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -386,7 +383,7 @@ async fn recorder_materializes_on_flush_with_pending_items() -> std::io::Result<
|
||||
);
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
.record_canonical_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "buffered-event".to_string(),
|
||||
phase: None,
|
||||
@@ -401,7 +398,7 @@ async fn recorder_materializes_on_flush_with_pending_items() -> std::io::Result<
|
||||
);
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
.record_canonical_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
UserMessageEvent {
|
||||
message: "first-user-message".to_string(),
|
||||
images: None,
|
||||
@@ -453,16 +450,13 @@ async fn persist_reports_filesystem_error_and_retries_buffered_items() -> std::i
|
||||
/*thread_source*/ None,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
/*state_db_ctx*/ None,
|
||||
/*state_builder*/ None,
|
||||
)
|
||||
.await?;
|
||||
let rollout_path = recorder.rollout_path().to_path_buf();
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
.record_canonical_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "buffered-before-persist".to_string(),
|
||||
phase: None,
|
||||
@@ -498,7 +492,6 @@ async fn persist_reports_filesystem_error_and_retries_buffered_items() -> std::i
|
||||
#[tokio::test]
|
||||
async fn writer_state_retries_write_error_before_reporting_flush_success() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let rollout_path = home.path().join("rollout.jsonl");
|
||||
File::create(&rollout_path)?;
|
||||
let read_only_file = std::fs::OpenOptions::new().read(true).open(&rollout_path)?;
|
||||
@@ -508,10 +501,6 @@ async fn writer_state_retries_write_error_before_reporting_flush_success() -> st
|
||||
/*meta*/ None,
|
||||
home.path().to_path_buf(),
|
||||
rollout_path.clone(),
|
||||
/*state_db_ctx*/ None,
|
||||
/*state_builder*/ None,
|
||||
config.model_provider_id.clone(),
|
||||
config.generate_memories,
|
||||
);
|
||||
state.add_items(vec![RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
@@ -530,216 +519,6 @@ async fn writer_state_retries_write_error_before_reporting_flush_success() -> st
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metadata_irrelevant_events_coalesce_state_db_updated_at() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
state_db
|
||||
.mark_backfill_complete(/*last_watermark*/ None)
|
||||
.await
|
||||
.expect("backfill should be complete");
|
||||
|
||||
let thread_id = ThreadId::new();
|
||||
let recorder = RolloutRecorder::new(
|
||||
&config,
|
||||
RolloutRecorderParams::new(
|
||||
thread_id,
|
||||
/*forked_from_id*/ None,
|
||||
SessionSource::Cli,
|
||||
/*thread_source*/ None,
|
||||
BaseInstructions::default(),
|
||||
Vec::new(),
|
||||
EventPersistenceMode::Limited,
|
||||
),
|
||||
Some(state_db.clone()),
|
||||
/*state_builder*/ None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
UserMessageEvent {
|
||||
message: "first-user-message".to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.persist().await?;
|
||||
recorder.flush().await?;
|
||||
let initial_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load")
|
||||
.expect("thread should exist");
|
||||
let initial_updated_at = initial_thread.updated_at;
|
||||
let initial_title = initial_thread.title.clone();
|
||||
let initial_first_user_message = initial_thread.first_user_message.clone();
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "assistant text".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.flush().await?;
|
||||
|
||||
let updated_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after agent message")
|
||||
.expect("thread should still exist");
|
||||
|
||||
assert_eq!(updated_thread.updated_at, initial_updated_at);
|
||||
assert_eq!(updated_thread.title, initial_title);
|
||||
assert_eq!(
|
||||
updated_thread.first_user_message,
|
||||
initial_first_user_message
|
||||
);
|
||||
|
||||
tokio::time::sleep(THREAD_UPDATED_AT_TOUCH_INTERVAL + Duration::from_millis(10)).await;
|
||||
|
||||
recorder
|
||||
.record_items(&[RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "more assistant text".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
},
|
||||
))])
|
||||
.await?;
|
||||
recorder.flush().await?;
|
||||
|
||||
let refreshed_thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after refresh")
|
||||
.expect("thread should still exist");
|
||||
assert!(refreshed_thread.updated_at > initial_updated_at);
|
||||
assert_eq!(refreshed_thread.title, initial_title);
|
||||
assert_eq!(
|
||||
refreshed_thread.first_user_message,
|
||||
initial_first_user_message
|
||||
);
|
||||
|
||||
recorder.shutdown().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shutdown_flushes_pending_metadata_irrelevant_updated_at() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
state_db
|
||||
.mark_backfill_complete(/*last_watermark*/ None)
|
||||
.await
|
||||
.expect("backfill should be complete");
|
||||
|
||||
let thread_id = ThreadId::new();
|
||||
let rollout_path = home.path().join("rollout.jsonl");
|
||||
let initial_updated_at = Utc.with_ymd_and_hms(2026, 5, 7, 7, 37, 8).unwrap();
|
||||
let builder = ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
rollout_path.clone(),
|
||||
initial_updated_at,
|
||||
SessionSource::Cli,
|
||||
);
|
||||
state_db
|
||||
.upsert_thread(&builder.build(config.model_provider_id.as_str()))
|
||||
.await
|
||||
.expect("thread should be inserted");
|
||||
|
||||
File::create(&rollout_path)?;
|
||||
let rollout_file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.open(&rollout_path)?;
|
||||
let mut state = RolloutWriterState::new(
|
||||
Some(tokio::fs::File::from_std(rollout_file)),
|
||||
/*deferred_log_file_info*/ None,
|
||||
/*meta*/ None,
|
||||
home.path().to_path_buf(),
|
||||
rollout_path,
|
||||
Some(state_db.clone()),
|
||||
Some(builder),
|
||||
config.model_provider_id.clone(),
|
||||
config.generate_memories,
|
||||
);
|
||||
let pending_updated_at = initial_updated_at + chrono::Duration::seconds(1);
|
||||
state.thread_updated_at_touch.pending_touch = Some((thread_id, pending_updated_at));
|
||||
|
||||
state.shutdown().await?;
|
||||
|
||||
assert_eq!(
|
||||
state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after shutdown")
|
||||
.expect("thread should still exist")
|
||||
.updated_at,
|
||||
pending_updated_at
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metadata_irrelevant_events_fall_back_to_upsert_when_thread_missing() -> std::io::Result<()>
|
||||
{
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
|
||||
let state_db = StateRuntime::init(home.path().to_path_buf(), config.model_provider_id.clone())
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let thread_id = ThreadId::new();
|
||||
let rollout_path = home.path().join("rollout.jsonl");
|
||||
let builder = ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
rollout_path.clone(),
|
||||
Utc::now(),
|
||||
SessionSource::Cli,
|
||||
);
|
||||
let items = vec![RolloutItem::EventMsg(EventMsg::AgentMessage(
|
||||
AgentMessageEvent {
|
||||
message: "assistant text".to_string(),
|
||||
phase: None,
|
||||
memory_citation: None,
|
||||
},
|
||||
))];
|
||||
|
||||
let mut thread_updated_at_touch = ThreadUpdatedAtTouch::default();
|
||||
sync_thread_state_after_write(
|
||||
Some(state_db.as_ref()),
|
||||
rollout_path.as_path(),
|
||||
Some(&builder),
|
||||
items.as_slice(),
|
||||
config.model_provider_id.as_str(),
|
||||
/*new_thread_memory_mode*/ None,
|
||||
&mut thread_updated_at_touch,
|
||||
)
|
||||
.await;
|
||||
|
||||
let thread = state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("thread should load after fallback")
|
||||
.expect("thread should be inserted after fallback");
|
||||
assert_eq!(thread.id, thread_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_threads_db_disabled_does_not_skip_paginated_items() -> std::io::Result<()> {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
|
||||
@@ -19,6 +19,7 @@ codex-git-utils = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-rollout = { workspace = true }
|
||||
codex-state = { workspace = true }
|
||||
codex-utils-path = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# Thread Store
|
||||
|
||||
`codex-thread-store` is the storage boundary for Codex threads. It defines the
|
||||
`ThreadStore` trait plus local and in-memory implementations. Other storage
|
||||
implementations may live outside this repository.
|
||||
|
||||
## Responsibilities
|
||||
|
||||
- `ThreadStore::append_items` is the raw canonical history append API. It does
|
||||
not infer metadata from item contents.
|
||||
- `ThreadStore::update_thread_metadata` is the only thread metadata write API.
|
||||
It accepts a single literal metadata patch shape, regardless of whether the
|
||||
caller is applying a user/API mutation or facts derived above the store from
|
||||
appended history.
|
||||
- `LiveThread` is the preferred API for active session persistence. It owns a
|
||||
per-thread metadata sync helper, applies the rollout persistence policy,
|
||||
appends canonical history, and then sends metadata patches through
|
||||
`ThreadStore::update_thread_metadata`.
|
||||
- `ThreadManager` routes metadata mutations for loaded and cold threads through
|
||||
one entrypoint. Loaded threads use their `LiveThread`; cold threads go
|
||||
directly to the store.
|
||||
- `LocalThreadStore` persists history through `codex-rollout` JSONL files and
|
||||
persists queryable metadata through the SQLite state database when available.
|
||||
Local explicit metadata mutations also maintain JSONL/name-index compatibility
|
||||
so reading old or SQLite-less local storage keeps working.
|
||||
- `RolloutRecorder` is the local JSONL writer. It writes already-canonical
|
||||
items for `ThreadStore::append_items`; it no longer decides metadata updates
|
||||
for live thread-store appends.
|
||||
- `core/session` creates or resumes `LiveThread` handles and does not need to
|
||||
know whether persistence is backed by local files or another store.
|
||||
|
||||
## Direction
|
||||
|
||||
New metadata observation semantics should live above `ThreadStore`. Stores
|
||||
persist explicit metadata fields, but raw history appends remain history-only.
|
||||
@@ -22,6 +22,7 @@ use crate::ReadThreadParams;
|
||||
use crate::ResumeThreadParams;
|
||||
use crate::StoredThread;
|
||||
use crate::StoredThreadHistory;
|
||||
use crate::ThreadMetadataPatch;
|
||||
use crate::ThreadPage;
|
||||
use crate::ThreadStore;
|
||||
use crate::ThreadStoreError;
|
||||
@@ -127,6 +128,7 @@ struct InMemoryThreadStoreState {
|
||||
calls: InMemoryThreadStoreCalls,
|
||||
created_threads: HashMap<ThreadId, CreateThreadParams>,
|
||||
histories: HashMap<ThreadId, Vec<RolloutItem>>,
|
||||
metadata_updates: HashMap<ThreadId, ThreadMetadataPatch>,
|
||||
names: HashMap<ThreadId, Option<String>>,
|
||||
rollout_paths: HashMap<PathBuf, ThreadId>,
|
||||
}
|
||||
@@ -271,9 +273,14 @@ impl ThreadStore for InMemoryThreadStore {
|
||||
) -> ThreadStoreResult<StoredThread> {
|
||||
let mut state = self.state.lock().await;
|
||||
state.calls.update_thread_metadata += 1;
|
||||
if let Some(name) = params.patch.name {
|
||||
state.names.insert(params.thread_id, Some(name));
|
||||
if let Some(name) = params.patch.name.clone() {
|
||||
state.names.insert(params.thread_id, name);
|
||||
}
|
||||
state
|
||||
.metadata_updates
|
||||
.entry(params.thread_id)
|
||||
.or_default()
|
||||
.merge(params.patch);
|
||||
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
|
||||
}
|
||||
|
||||
@@ -307,6 +314,7 @@ fn stored_thread_from_state(
|
||||
items: history_items.clone(),
|
||||
});
|
||||
let name = state.names.get(&thread_id).cloned().flatten();
|
||||
let metadata = state.metadata_updates.get(&thread_id);
|
||||
let rollout_path = state
|
||||
.rollout_paths
|
||||
.iter()
|
||||
@@ -316,28 +324,65 @@ fn stored_thread_from_state(
|
||||
|
||||
Ok(StoredThread {
|
||||
thread_id,
|
||||
rollout_path,
|
||||
rollout_path: metadata
|
||||
.and_then(|metadata| metadata.rollout_path.clone())
|
||||
.or(rollout_path),
|
||||
forked_from_id: created.forked_from_id,
|
||||
preview: String::new(),
|
||||
preview: metadata
|
||||
.and_then(|metadata| metadata.preview.clone())
|
||||
.unwrap_or_default(),
|
||||
name,
|
||||
model_provider: "test".to_string(),
|
||||
model: None,
|
||||
reasoning_effort: None,
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
model_provider: metadata
|
||||
.and_then(|metadata| metadata.model_provider.clone())
|
||||
.unwrap_or_else(|| "test".to_string()),
|
||||
model: metadata.and_then(|metadata| metadata.model.clone()),
|
||||
reasoning_effort: metadata.and_then(|metadata| metadata.reasoning_effort),
|
||||
created_at: metadata
|
||||
.and_then(|metadata| metadata.created_at)
|
||||
.unwrap_or_else(Utc::now),
|
||||
updated_at: metadata
|
||||
.and_then(|metadata| metadata.updated_at)
|
||||
.unwrap_or_else(Utc::now),
|
||||
archived_at: None,
|
||||
cwd: PathBuf::new(),
|
||||
cli_version: "test".to_string(),
|
||||
source: created.source.clone(),
|
||||
thread_source: created.thread_source,
|
||||
agent_nickname: None,
|
||||
agent_role: None,
|
||||
agent_path: None,
|
||||
git_info: None,
|
||||
approval_mode: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::new_read_only_policy(),
|
||||
token_usage: None,
|
||||
first_user_message: None,
|
||||
cwd: metadata
|
||||
.and_then(|metadata| metadata.cwd.clone())
|
||||
.unwrap_or_default(),
|
||||
cli_version: metadata
|
||||
.and_then(|metadata| metadata.cli_version.clone())
|
||||
.unwrap_or_else(|| "test".to_string()),
|
||||
source: metadata
|
||||
.and_then(|metadata| metadata.source.clone())
|
||||
.unwrap_or_else(|| created.source.clone()),
|
||||
thread_source: metadata
|
||||
.and_then(|metadata| metadata.thread_source)
|
||||
.unwrap_or(created.thread_source),
|
||||
agent_nickname: metadata.and_then(|metadata| metadata.agent_nickname.clone().flatten()),
|
||||
agent_role: metadata.and_then(|metadata| metadata.agent_role.clone().flatten()),
|
||||
agent_path: metadata.and_then(|metadata| metadata.agent_path.clone().flatten()),
|
||||
git_info: metadata.and_then(git_info_from_patch),
|
||||
approval_mode: metadata
|
||||
.and_then(|metadata| metadata.approval_mode)
|
||||
.unwrap_or(AskForApproval::Never),
|
||||
sandbox_policy: metadata
|
||||
.and_then(|metadata| metadata.sandbox_policy.clone())
|
||||
.unwrap_or_else(SandboxPolicy::new_read_only_policy),
|
||||
token_usage: metadata.and_then(|metadata| metadata.token_usage.clone()),
|
||||
first_user_message: metadata.and_then(|metadata| metadata.first_user_message.clone()),
|
||||
history,
|
||||
})
|
||||
}
|
||||
|
||||
fn git_info_from_patch(patch: &ThreadMetadataPatch) -> Option<codex_protocol::protocol::GitInfo> {
|
||||
let git_info = patch.git_info.as_ref()?;
|
||||
let sha = git_info.sha.clone().flatten();
|
||||
let branch = git_info.branch.clone().flatten();
|
||||
let origin_url = git_info.origin_url.clone().flatten();
|
||||
if sha.is_none() && branch.is_none() && origin_url.is_none() {
|
||||
return None;
|
||||
}
|
||||
Some(codex_protocol::protocol::GitInfo {
|
||||
commit_hash: sha.as_deref().map(codex_git_utils::GitSha::new),
|
||||
branch,
|
||||
repository_url: origin_url,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ mod in_memory;
|
||||
mod live_thread;
|
||||
mod local;
|
||||
mod store;
|
||||
mod thread_metadata_sync;
|
||||
mod types;
|
||||
|
||||
pub use error::ThreadStoreError;
|
||||
@@ -22,6 +23,7 @@ pub use local::LocalThreadStoreConfig;
|
||||
pub use store::ThreadStore;
|
||||
pub use types::AppendThreadItemsParams;
|
||||
pub use types::ArchiveThreadParams;
|
||||
pub use types::ClearableField;
|
||||
pub use types::CreateThreadParams;
|
||||
pub use types::GitInfoPatch;
|
||||
pub use types::ItemPage;
|
||||
@@ -29,7 +31,6 @@ pub use types::ListItemsParams;
|
||||
pub use types::ListThreadsParams;
|
||||
pub use types::ListTurnsParams;
|
||||
pub use types::LoadThreadHistoryParams;
|
||||
pub use types::OptionalStringPatch;
|
||||
pub use types::ReadThreadByRolloutPathParams;
|
||||
pub use types::ReadThreadParams;
|
||||
pub use types::ResumeThreadParams;
|
||||
|
||||
@@ -4,6 +4,9 @@ use std::sync::Arc;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::ThreadMemoryMode;
|
||||
use codex_rollout::EventPersistenceMode;
|
||||
use codex_rollout::persisted_rollout_items;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::AppendThreadItemsParams;
|
||||
@@ -14,10 +17,12 @@ use crate::ReadThreadParams;
|
||||
use crate::ResumeThreadParams;
|
||||
use crate::StoredThread;
|
||||
use crate::StoredThreadHistory;
|
||||
use crate::ThreadEventPersistenceMode;
|
||||
use crate::ThreadMetadataPatch;
|
||||
use crate::ThreadStore;
|
||||
use crate::ThreadStoreResult;
|
||||
use crate::UpdateThreadMetadataParams;
|
||||
use crate::thread_metadata_sync::ThreadMetadataSync;
|
||||
|
||||
/// Handle for an active thread's persistence lifecycle.
|
||||
///
|
||||
@@ -28,6 +33,8 @@ use crate::UpdateThreadMetadataParams;
|
||||
pub struct LiveThread {
|
||||
thread_id: ThreadId,
|
||||
thread_store: Arc<dyn ThreadStore>,
|
||||
event_persistence_mode: EventPersistenceMode,
|
||||
metadata_sync: Arc<Mutex<ThreadMetadataSync>>,
|
||||
}
|
||||
|
||||
/// Owns a live thread while session initialization is still fallible.
|
||||
@@ -85,43 +92,96 @@ impl LiveThread {
|
||||
params: CreateThreadParams,
|
||||
) -> ThreadStoreResult<Self> {
|
||||
let thread_id = params.thread_id;
|
||||
let event_persistence_mode = event_persistence_mode(params.event_persistence_mode);
|
||||
let metadata_sync = ThreadMetadataSync::for_create(¶ms).await;
|
||||
thread_store.create_thread(params).await?;
|
||||
Ok(Self {
|
||||
thread_id,
|
||||
thread_store,
|
||||
event_persistence_mode,
|
||||
metadata_sync: Arc::new(Mutex::new(metadata_sync)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn resume(
|
||||
thread_store: Arc<dyn ThreadStore>,
|
||||
params: ResumeThreadParams,
|
||||
mut params: ResumeThreadParams,
|
||||
) -> ThreadStoreResult<Self> {
|
||||
let thread_id = params.thread_id;
|
||||
thread_store.resume_thread(params).await?;
|
||||
let event_persistence_mode = event_persistence_mode(params.event_persistence_mode);
|
||||
let should_load_history = params.history.is_none();
|
||||
let include_archived = params.include_archived;
|
||||
thread_store.resume_thread(params.clone()).await?;
|
||||
if should_load_history {
|
||||
match thread_store
|
||||
.load_history(LoadThreadHistoryParams {
|
||||
thread_id,
|
||||
include_archived,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(history) => params.history = Some(history.items),
|
||||
Err(err) => {
|
||||
let _ = thread_store.discard_thread(thread_id).await;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
let metadata_sync = ThreadMetadataSync::for_resume(¶ms);
|
||||
Ok(Self {
|
||||
thread_id,
|
||||
thread_store,
|
||||
event_persistence_mode,
|
||||
metadata_sync: Arc::new(Mutex::new(metadata_sync)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn append_items(&self, items: &[RolloutItem]) -> ThreadStoreResult<()> {
|
||||
let canonical_items = persisted_rollout_items(items, self.event_persistence_mode);
|
||||
if canonical_items.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.thread_store
|
||||
.append_items(AppendThreadItemsParams {
|
||||
thread_id: self.thread_id,
|
||||
items: items.to_vec(),
|
||||
items: canonical_items.clone(),
|
||||
})
|
||||
.await?;
|
||||
let update = self
|
||||
.metadata_sync
|
||||
.lock()
|
||||
.await
|
||||
.observe_appended_items(canonical_items.as_slice());
|
||||
if let Some(update) = update {
|
||||
self.thread_store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id: self.thread_id,
|
||||
patch: update.patch.clone(),
|
||||
include_archived: true,
|
||||
})
|
||||
.await?;
|
||||
self.metadata_sync
|
||||
.lock()
|
||||
.await
|
||||
.mark_pending_update_applied(&update);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn persist(&self) -> ThreadStoreResult<()> {
|
||||
self.thread_store.persist_thread(self.thread_id).await
|
||||
self.thread_store.persist_thread(self.thread_id).await?;
|
||||
self.flush_pending_metadata_update().await
|
||||
}
|
||||
|
||||
pub async fn flush(&self) -> ThreadStoreResult<()> {
|
||||
self.thread_store.flush_thread(self.thread_id).await
|
||||
self.thread_store.flush_thread(self.thread_id).await?;
|
||||
self.flush_pending_metadata_update_for_existing_history()
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> ThreadStoreResult<()> {
|
||||
self.flush_pending_metadata_update_for_existing_history()
|
||||
.await?;
|
||||
self.thread_store.shutdown_thread(self.thread_id).await
|
||||
}
|
||||
|
||||
@@ -160,6 +220,7 @@ impl LiveThread {
|
||||
mode: ThreadMemoryMode,
|
||||
include_archived: bool,
|
||||
) -> ThreadStoreResult<()> {
|
||||
self.flush_pending_metadata_update().await?;
|
||||
self.thread_store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id: self.thread_id,
|
||||
@@ -178,6 +239,7 @@ impl LiveThread {
|
||||
patch: ThreadMetadataPatch,
|
||||
include_archived: bool,
|
||||
) -> ThreadStoreResult<StoredThread> {
|
||||
self.flush_pending_metadata_update().await?;
|
||||
self.thread_store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id: self.thread_id,
|
||||
@@ -203,4 +265,46 @@ impl LiveThread {
|
||||
.await
|
||||
.map(Some)
|
||||
}
|
||||
|
||||
async fn flush_pending_metadata_update(&self) -> ThreadStoreResult<()> {
|
||||
let update = self.metadata_sync.lock().await.take_pending_update();
|
||||
self.apply_pending_metadata_update(update).await
|
||||
}
|
||||
|
||||
async fn flush_pending_metadata_update_for_existing_history(&self) -> ThreadStoreResult<()> {
|
||||
let update = self
|
||||
.metadata_sync
|
||||
.lock()
|
||||
.await
|
||||
.take_pending_update_for_existing_history();
|
||||
self.apply_pending_metadata_update(update).await
|
||||
}
|
||||
|
||||
async fn apply_pending_metadata_update(
|
||||
&self,
|
||||
update: Option<crate::thread_metadata_sync::PendingThreadMetadataPatch>,
|
||||
) -> ThreadStoreResult<()> {
|
||||
let Some(update) = update else {
|
||||
return Ok(());
|
||||
};
|
||||
self.thread_store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id: self.thread_id,
|
||||
patch: update.patch.clone(),
|
||||
include_archived: true,
|
||||
})
|
||||
.await?;
|
||||
self.metadata_sync
|
||||
.lock()
|
||||
.await
|
||||
.mark_pending_update_applied(&update);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn event_persistence_mode(mode: ThreadEventPersistenceMode) -> EventPersistenceMode {
|
||||
match mode {
|
||||
ThreadEventPersistenceMode::Limited => EventPersistenceMode::Limited,
|
||||
ThreadEventPersistenceMode::Extended => EventPersistenceMode::Extended,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use super::LocalThreadStore;
|
||||
use crate::CreateThreadParams;
|
||||
use crate::ThreadEventPersistenceMode;
|
||||
use crate::ThreadStoreError;
|
||||
use crate::ThreadStoreResult;
|
||||
use codex_protocol::protocol::ThreadMemoryMode;
|
||||
use codex_rollout::EventPersistenceMode;
|
||||
use codex_rollout::RolloutConfig;
|
||||
use codex_rollout::RolloutRecorder;
|
||||
use codex_rollout::RolloutRecorderParams;
|
||||
@@ -27,7 +25,6 @@ pub(super) async fn create_thread(
|
||||
model_provider_id: params.metadata.model_provider.clone(),
|
||||
generate_memories: matches!(params.metadata.memory_mode, ThreadMemoryMode::Enabled),
|
||||
};
|
||||
let state_db_ctx = store.state_db().await;
|
||||
let recorder = RolloutRecorder::new(
|
||||
&config,
|
||||
RolloutRecorderParams::new(
|
||||
@@ -37,10 +34,7 @@ pub(super) async fn create_thread(
|
||||
params.thread_source,
|
||||
params.base_instructions,
|
||||
params.dynamic_tools,
|
||||
event_persistence_mode(params.event_persistence_mode),
|
||||
),
|
||||
state_db_ctx,
|
||||
/*state_builder*/ None,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
@@ -49,10 +43,3 @@ pub(super) async fn create_thread(
|
||||
|
||||
Ok(recorder)
|
||||
}
|
||||
|
||||
pub(super) fn event_persistence_mode(mode: ThreadEventPersistenceMode) -> EventPersistenceMode {
|
||||
match mode {
|
||||
ThreadEventPersistenceMode::Limited => EventPersistenceMode::Limited,
|
||||
ThreadEventPersistenceMode::Extended => EventPersistenceMode::Extended,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use codex_protocol::protocol::ThreadMemoryMode;
|
||||
use codex_rollout::RolloutConfig;
|
||||
use codex_rollout::RolloutRecorder;
|
||||
use codex_rollout::RolloutRecorderParams;
|
||||
use codex_rollout::builder_from_items;
|
||||
use tracing::warn;
|
||||
|
||||
use super::LocalThreadStore;
|
||||
use super::create_thread;
|
||||
@@ -31,8 +31,8 @@ pub(super) async fn resume_thread(
|
||||
params: ResumeThreadParams,
|
||||
) -> ThreadStoreResult<()> {
|
||||
store.ensure_live_recorder_absent(params.thread_id).await?;
|
||||
let (rollout_path, history) = match (params.rollout_path, params.history) {
|
||||
(Some(rollout_path), history) => (rollout_path, history),
|
||||
let rollout_path = match (params.rollout_path, params.history) {
|
||||
(Some(rollout_path), _history) => rollout_path,
|
||||
(None, history) => {
|
||||
let thread = super::read_thread::read_thread(
|
||||
store,
|
||||
@@ -43,20 +43,14 @@ pub(super) async fn resume_thread(
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let rollout_path = thread
|
||||
|
||||
thread
|
||||
.rollout_path
|
||||
.ok_or_else(|| ThreadStoreError::Internal {
|
||||
message: format!("thread {} does not have a rollout path", params.thread_id),
|
||||
})?;
|
||||
(
|
||||
rollout_path,
|
||||
history.or_else(|| thread.history.map(|history| history.items)),
|
||||
)
|
||||
})?
|
||||
}
|
||||
};
|
||||
let state_builder = history
|
||||
.as_deref()
|
||||
.and_then(|items| builder_from_items(items, rollout_path.as_path()));
|
||||
let cwd = params
|
||||
.metadata
|
||||
.cwd
|
||||
@@ -71,20 +65,11 @@ pub(super) async fn resume_thread(
|
||||
model_provider_id: params.metadata.model_provider.clone(),
|
||||
generate_memories: matches!(params.metadata.memory_mode, ThreadMemoryMode::Enabled),
|
||||
};
|
||||
let state_db_ctx = store.state_db().await;
|
||||
let recorder = RolloutRecorder::new(
|
||||
&config,
|
||||
RolloutRecorderParams::resume(
|
||||
rollout_path,
|
||||
create_thread::event_persistence_mode(params.event_persistence_mode),
|
||||
),
|
||||
state_db_ctx,
|
||||
state_builder,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to resume local thread recorder: {err}"),
|
||||
})?;
|
||||
let recorder = RolloutRecorder::new(&config, RolloutRecorderParams::resume(rollout_path))
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to resume local thread recorder: {err}"),
|
||||
})?;
|
||||
store.insert_live_recorder(params.thread_id, recorder).await
|
||||
}
|
||||
|
||||
@@ -92,12 +77,14 @@ pub(super) async fn append_items(
|
||||
store: &LocalThreadStore,
|
||||
params: AppendThreadItemsParams,
|
||||
) -> ThreadStoreResult<()> {
|
||||
store
|
||||
.live_recorder(params.thread_id)
|
||||
.await?
|
||||
.record_items(params.items.as_slice())
|
||||
let recorder = store.live_recorder(params.thread_id).await?;
|
||||
recorder
|
||||
.record_canonical_items(params.items.as_slice())
|
||||
.await
|
||||
.map_err(thread_store_io_error)
|
||||
.map_err(thread_store_io_error)?;
|
||||
// LiveThread applies metadata immediately after append_items returns. Wait for the local
|
||||
// writer so SQLite never gets ahead of JSONL for accepted live appends.
|
||||
recorder.flush().await.map_err(thread_store_io_error)
|
||||
}
|
||||
|
||||
pub(super) async fn persist_thread(
|
||||
@@ -109,7 +96,8 @@ pub(super) async fn persist_thread(
|
||||
.await?
|
||||
.persist()
|
||||
.await
|
||||
.map_err(thread_store_io_error)
|
||||
.map_err(thread_store_io_error)?;
|
||||
sync_materialized_rollout_path(store, thread_id).await
|
||||
}
|
||||
|
||||
pub(super) async fn flush_thread(
|
||||
@@ -121,7 +109,8 @@ pub(super) async fn flush_thread(
|
||||
.await?
|
||||
.flush()
|
||||
.await
|
||||
.map_err(thread_store_io_error)
|
||||
.map_err(thread_store_io_error)?;
|
||||
sync_materialized_rollout_path(store, thread_id).await
|
||||
}
|
||||
|
||||
pub(super) async fn shutdown_thread(
|
||||
@@ -130,6 +119,7 @@ pub(super) async fn shutdown_thread(
|
||||
) -> ThreadStoreResult<()> {
|
||||
let recorder = store.live_recorder(thread_id).await?;
|
||||
recorder.shutdown().await.map_err(thread_store_io_error)?;
|
||||
sync_materialized_rollout_path(store, thread_id).await?;
|
||||
store.live_recorders.lock().await.remove(&thread_id);
|
||||
Ok(())
|
||||
}
|
||||
@@ -161,6 +151,49 @@ pub(super) async fn rollout_path(
|
||||
.to_path_buf())
|
||||
}
|
||||
|
||||
async fn sync_materialized_rollout_path(
|
||||
store: &LocalThreadStore,
|
||||
thread_id: ThreadId,
|
||||
) -> ThreadStoreResult<()> {
|
||||
let rollout_path = rollout_path(store, thread_id).await?;
|
||||
if !tokio::fs::try_exists(rollout_path.as_path())
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
let Some(state_db) = store.state_db().await else {
|
||||
return Ok(());
|
||||
};
|
||||
let result: ThreadStoreResult<()> = async {
|
||||
let Some(mut metadata) =
|
||||
state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to read thread metadata for {thread_id}: {err}"),
|
||||
})?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
if metadata.rollout_path != rollout_path {
|
||||
metadata.rollout_path = rollout_path;
|
||||
state_db
|
||||
.upsert_thread(&metadata)
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to update thread metadata for {thread_id}: {err}"),
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
if let Err(err) = result {
|
||||
warn!("failed to sync materialized rollout path for thread {thread_id}: {err}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn thread_store_io_error(err: std::io::Error) -> ThreadStoreError {
|
||||
ThreadStoreError::Internal {
|
||||
message: err.to_string(),
|
||||
|
||||
@@ -37,6 +37,18 @@ use crate::ThreadStoreResult;
|
||||
use crate::UpdateThreadMetadataParams;
|
||||
|
||||
/// Local filesystem/SQLite-backed implementation of [`ThreadStore`].
|
||||
///
|
||||
/// Local storage has two compatibility surfaces. Rollout JSONL files are the
|
||||
/// durable replay format and remain readable without SQLite, including older
|
||||
/// files that encode metadata in `SessionMeta` items and name-index entries.
|
||||
/// The SQLite state DB, when available, is the queryable metadata index used by
|
||||
/// list/read paths for fast lookup.
|
||||
///
|
||||
/// Live appends still write canonical JSONL history, but append-derived
|
||||
/// metadata is observed above the store and applied through
|
||||
/// [`ThreadStore::update_thread_metadata`]. This implementation applies that
|
||||
/// patch literally to SQLite while keeping the JSONL/name-index compatibility
|
||||
/// behavior needed for SQLite-less reads, repair, and old local rollout files.
|
||||
#[derive(Clone)]
|
||||
pub struct LocalThreadStore {
|
||||
pub(super) config: LocalThreadStoreConfig,
|
||||
@@ -270,6 +282,8 @@ impl ThreadStore for LocalThreadStore {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::models::BaseInstructions;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
@@ -280,6 +294,7 @@ mod tests {
|
||||
use tempfile::TempDir;
|
||||
|
||||
use super::*;
|
||||
use crate::LiveThread;
|
||||
use crate::ThreadEventPersistenceMode;
|
||||
use crate::ThreadPersistenceMetadata;
|
||||
use crate::local::test_support::test_config;
|
||||
@@ -335,6 +350,267 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn raw_append_items_does_not_update_sqlite_metadata() {
|
||||
// This pins the ThreadStore contract: raw appends are history-only. Callers that need
|
||||
// metadata updates must use LiveThread or call update_thread_metadata explicitly.
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime.clone()));
|
||||
let thread_id = ThreadId::default();
|
||||
|
||||
store
|
||||
.create_thread(create_thread_params(thread_id))
|
||||
.await
|
||||
.expect("create live thread");
|
||||
store
|
||||
.append_items(AppendThreadItemsParams {
|
||||
thread_id,
|
||||
items: vec![user_message_item("raw append")],
|
||||
})
|
||||
.await
|
||||
.expect("append raw item");
|
||||
store.flush_thread(thread_id).await.expect("flush thread");
|
||||
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read"),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn live_thread_observes_appended_items_into_sqlite_metadata() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = Arc::new(LocalThreadStore::new(config, Some(runtime.clone())));
|
||||
let thread_id = ThreadId::default();
|
||||
let live_thread = LiveThread::create(store.clone(), create_thread_params(thread_id))
|
||||
.await
|
||||
.expect("create live thread");
|
||||
|
||||
live_thread
|
||||
.append_items(&[user_message_item("observed append")])
|
||||
.await
|
||||
.expect("append observed item");
|
||||
live_thread.flush().await.expect("flush thread");
|
||||
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read")
|
||||
.expect("sqlite metadata");
|
||||
assert_eq!(
|
||||
metadata.first_user_message.as_deref(),
|
||||
Some("observed append")
|
||||
);
|
||||
assert_eq!(metadata.preview.as_deref(), Some("observed append"));
|
||||
assert_eq!(metadata.title, "observed append");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn live_thread_shutdown_does_not_materialize_empty_thread_metadata() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = Arc::new(LocalThreadStore::new(config, Some(runtime.clone())));
|
||||
let thread_id = ThreadId::default();
|
||||
let live_thread = LiveThread::create(store.clone(), create_thread_params(thread_id))
|
||||
.await
|
||||
.expect("create live thread");
|
||||
let rollout_path = store
|
||||
.live_rollout_path(thread_id)
|
||||
.await
|
||||
.expect("live rollout path");
|
||||
|
||||
live_thread.shutdown().await.expect("shutdown thread");
|
||||
|
||||
assert!(
|
||||
!tokio::fs::try_exists(rollout_path.as_path())
|
||||
.await
|
||||
.expect("rollout path should be checkable")
|
||||
);
|
||||
assert_eq!(
|
||||
runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read"),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn live_thread_shutdown_with_buffered_items_materializes_before_metadata_read() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = Arc::new(LocalThreadStore::new(config, Some(runtime.clone())));
|
||||
let thread_id = ThreadId::default();
|
||||
let live_thread = LiveThread::create(store.clone(), create_thread_params(thread_id))
|
||||
.await
|
||||
.expect("create live thread");
|
||||
let rollout_path = store
|
||||
.live_rollout_path(thread_id)
|
||||
.await
|
||||
.expect("live rollout path");
|
||||
|
||||
live_thread
|
||||
.append_items(&[RolloutItem::EventMsg(EventMsg::TokenCount(
|
||||
codex_protocol::protocol::TokenCountEvent {
|
||||
info: None,
|
||||
rate_limits: None,
|
||||
},
|
||||
))])
|
||||
.await
|
||||
.expect("append metadata-only item");
|
||||
live_thread.shutdown().await.expect("shutdown thread");
|
||||
|
||||
assert!(
|
||||
tokio::fs::try_exists(rollout_path.as_path())
|
||||
.await
|
||||
.expect("rollout path should be checkable")
|
||||
);
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read")
|
||||
.expect("sqlite metadata");
|
||||
assert_eq!(metadata.rollout_path, rollout_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn live_thread_resume_loads_history_before_observing_metadata() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = Arc::new(LocalThreadStore::new(config, Some(runtime.clone())));
|
||||
let uuid = uuid::Uuid::from_u128(401);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
let rollout_path =
|
||||
write_session_file(home.path(), "2025-01-03T17-00-00", uuid).expect("session file");
|
||||
let live_thread = LiveThread::resume(
|
||||
store,
|
||||
ResumeThreadParams {
|
||||
thread_id,
|
||||
rollout_path: Some(rollout_path),
|
||||
history: None,
|
||||
include_archived: false,
|
||||
metadata: ThreadPersistenceMetadata {
|
||||
cwd: Some(home.path().to_path_buf()),
|
||||
model_provider: "different-provider".to_string(),
|
||||
memory_mode: ThreadMemoryMode::Enabled,
|
||||
},
|
||||
event_persistence_mode: ThreadEventPersistenceMode::Limited,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("resume live thread");
|
||||
|
||||
live_thread
|
||||
.append_items(&[user_message_item("new live append")])
|
||||
.await
|
||||
.expect("append after resume");
|
||||
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read")
|
||||
.expect("sqlite metadata");
|
||||
assert_eq!(
|
||||
metadata.created_at.to_rfc3339(),
|
||||
"2025-01-03T17:00:00+00:00"
|
||||
);
|
||||
assert_eq!(metadata.model_provider, "test-provider");
|
||||
assert_eq!(
|
||||
metadata.first_user_message.as_deref(),
|
||||
Some("Hello from user")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn live_thread_resume_loads_history_from_explicit_external_rollout_path() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let external_home = TempDir::new().expect("external temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
config.sqlite_home.clone(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = Arc::new(LocalThreadStore::new(config, Some(runtime.clone())));
|
||||
let uuid = uuid::Uuid::from_u128(402);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
let rollout_path = write_session_file(external_home.path(), "2025-01-03T17-30-00", uuid)
|
||||
.expect("external session file");
|
||||
let live_thread = LiveThread::resume(
|
||||
store,
|
||||
ResumeThreadParams {
|
||||
thread_id,
|
||||
rollout_path: Some(rollout_path),
|
||||
history: None,
|
||||
include_archived: false,
|
||||
metadata: ThreadPersistenceMetadata {
|
||||
cwd: Some(home.path().to_path_buf()),
|
||||
model_provider: "different-provider".to_string(),
|
||||
memory_mode: ThreadMemoryMode::Enabled,
|
||||
},
|
||||
event_persistence_mode: ThreadEventPersistenceMode::Limited,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("resume external live thread");
|
||||
|
||||
live_thread
|
||||
.append_items(&[user_message_item("new external append")])
|
||||
.await
|
||||
.expect("append after external resume");
|
||||
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read")
|
||||
.expect("sqlite metadata");
|
||||
assert_eq!(
|
||||
metadata.created_at.to_rfc3339(),
|
||||
"2025-01-03T17:30:00+00:00"
|
||||
);
|
||||
assert_eq!(metadata.model_provider, "test-provider");
|
||||
assert_eq!(
|
||||
metadata.first_user_message.as_deref(),
|
||||
Some("Hello from user")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_thread_rejects_missing_cwd() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
|
||||
@@ -273,7 +273,8 @@ async fn stored_thread_from_sqlite_metadata(
|
||||
None => find_thread_name_by_id(store.config.codex_home.as_path(), &metadata.id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten(),
|
||||
.flatten()
|
||||
.filter(|title| !title.trim().is_empty()),
|
||||
};
|
||||
let session_meta = read_session_meta_line(metadata.rollout_path.as_path())
|
||||
.await
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use chrono::Utc;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::GitInfo;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::ThreadMemoryMode;
|
||||
use codex_rollout::ARCHIVED_SESSIONS_SUBDIR;
|
||||
use codex_rollout::append_rollout_item_to_path;
|
||||
@@ -11,6 +13,7 @@ use codex_rollout::append_thread_name;
|
||||
use codex_rollout::find_archived_thread_path_by_id_str;
|
||||
use codex_rollout::find_thread_path_by_id_str;
|
||||
use codex_rollout::read_session_meta_line;
|
||||
use codex_state::ThreadMetadataBuilder;
|
||||
|
||||
use super::LocalThreadStore;
|
||||
use super::helpers::git_info_from_parts;
|
||||
@@ -18,6 +21,7 @@ use super::live_writer;
|
||||
use crate::GitInfoPatch;
|
||||
use crate::ReadThreadParams;
|
||||
use crate::StoredThread;
|
||||
use crate::ThreadMetadataPatch;
|
||||
use crate::ThreadStoreError;
|
||||
use crate::ThreadStoreResult;
|
||||
use crate::UpdateThreadMetadataParams;
|
||||
@@ -32,25 +36,35 @@ pub(super) async fn update_thread_metadata(
|
||||
store: &LocalThreadStore,
|
||||
params: UpdateThreadMetadataParams,
|
||||
) -> ThreadStoreResult<StoredThread> {
|
||||
let field_count = usize::from(params.patch.name.is_some())
|
||||
+ usize::from(params.patch.memory_mode.is_some())
|
||||
+ usize::from(params.patch.git_info.is_some());
|
||||
if field_count > 1 {
|
||||
return Err(ThreadStoreError::InvalidRequest {
|
||||
message: "local thread store applies one metadata field per patch in this slice"
|
||||
.to_string(),
|
||||
});
|
||||
let thread_id = params.thread_id;
|
||||
let patch = params.patch;
|
||||
if patch.is_empty() {
|
||||
return read_thread::read_thread(
|
||||
store,
|
||||
ReadThreadParams {
|
||||
thread_id,
|
||||
include_archived: params.include_archived,
|
||||
include_history: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let needs_rollout_compat = needs_rollout_compatibility_update(&patch);
|
||||
let updated =
|
||||
apply_metadata_update(store, thread_id, patch.clone(), params.include_archived).await?;
|
||||
if !needs_rollout_compat {
|
||||
return Ok(updated);
|
||||
}
|
||||
|
||||
let thread_id = params.thread_id;
|
||||
if live_writer::rollout_path(store, thread_id).await.is_ok() {
|
||||
live_writer::persist_thread(store, thread_id).await?;
|
||||
}
|
||||
let resolved_rollout_path =
|
||||
resolve_rollout_path(store, thread_id, params.include_archived).await?;
|
||||
let name = params.patch.name;
|
||||
let git_info = params.patch.git_info;
|
||||
if let Some(memory_mode) = params.patch.memory_mode {
|
||||
let name = patch.name;
|
||||
let git_info = patch.git_info;
|
||||
if let Some(memory_mode) = patch.memory_mode {
|
||||
apply_thread_memory_mode(resolved_rollout_path.path.as_path(), thread_id, memory_mode)
|
||||
.await?;
|
||||
}
|
||||
@@ -68,7 +82,7 @@ pub(super) async fn update_thread_metadata(
|
||||
.await;
|
||||
|
||||
if let Some(name) = name {
|
||||
apply_thread_name(store, thread_id, name).await?;
|
||||
apply_thread_name(store, thread_id, name.unwrap_or_default()).await?;
|
||||
}
|
||||
|
||||
let resolved_git_info = match git_info {
|
||||
@@ -150,6 +164,218 @@ pub(super) async fn update_thread_metadata(
|
||||
Ok(thread)
|
||||
}
|
||||
|
||||
async fn apply_metadata_update(
|
||||
store: &LocalThreadStore,
|
||||
thread_id: ThreadId,
|
||||
patch: ThreadMetadataPatch,
|
||||
include_archived: bool,
|
||||
) -> ThreadStoreResult<StoredThread> {
|
||||
let live_rollout_path = live_writer::rollout_path(store, thread_id).await.ok();
|
||||
let mut rollout_path = patch.rollout_path.clone().or(live_rollout_path);
|
||||
let mut rollout_path_archived = rollout_path
|
||||
.as_deref()
|
||||
.is_some_and(|path| rollout_path_is_archived(store, path));
|
||||
let state_db = store.state_db().await;
|
||||
let sqlite_write_result: ThreadStoreResult<()> = if let Some(state_db) = state_db.as_ref() {
|
||||
let patch = patch.clone();
|
||||
async {
|
||||
let existing =
|
||||
state_db
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to read thread metadata for {thread_id}: {err}"),
|
||||
})?;
|
||||
if existing.is_none() && rollout_path.is_none() {
|
||||
let resolved = resolve_rollout_path(store, thread_id, include_archived).await?;
|
||||
rollout_path_archived = resolved.archived;
|
||||
rollout_path = Some(resolved.path);
|
||||
}
|
||||
let mut metadata = existing.clone().unwrap_or_else(|| {
|
||||
let created_at = patch
|
||||
.created_at
|
||||
.or(patch.updated_at)
|
||||
.unwrap_or_else(Utc::now);
|
||||
let mut builder = ThreadMetadataBuilder::new(
|
||||
thread_id,
|
||||
rollout_path.clone().unwrap_or_default(),
|
||||
created_at,
|
||||
patch.source.clone().unwrap_or(SessionSource::Unknown),
|
||||
);
|
||||
builder.model_provider = patch.model_provider.clone();
|
||||
builder.thread_source = patch.thread_source.flatten();
|
||||
builder.agent_nickname = patch.agent_nickname.clone().flatten();
|
||||
builder.agent_role = patch.agent_role.clone().flatten();
|
||||
builder.agent_path = patch.agent_path.clone().flatten();
|
||||
builder.cwd = patch.cwd.clone().map(normalize_cwd).unwrap_or_default();
|
||||
builder.cli_version = patch.cli_version.clone();
|
||||
let mut metadata = builder.build(store.config.default_model_provider_id.as_str());
|
||||
if rollout_path_archived {
|
||||
metadata.archived_at = Some(metadata.updated_at);
|
||||
}
|
||||
metadata
|
||||
});
|
||||
if let Some(rollout_path) = rollout_path {
|
||||
metadata.rollout_path = rollout_path;
|
||||
}
|
||||
if let Some(preview) = patch.preview {
|
||||
metadata.preview = Some(preview);
|
||||
}
|
||||
if let Some(name) = patch.name {
|
||||
metadata.title = name.unwrap_or_default();
|
||||
}
|
||||
if let Some(title) = patch.title {
|
||||
metadata.title = title;
|
||||
}
|
||||
if let Some(model_provider) = patch.model_provider {
|
||||
metadata.model_provider = model_provider;
|
||||
}
|
||||
if let Some(model) = patch.model {
|
||||
metadata.model = Some(model);
|
||||
}
|
||||
if let Some(reasoning_effort) = patch.reasoning_effort {
|
||||
metadata.reasoning_effort = Some(reasoning_effort);
|
||||
}
|
||||
if let Some(created_at) = patch.created_at {
|
||||
metadata.created_at = created_at;
|
||||
}
|
||||
if let Some(updated_at) = patch.updated_at {
|
||||
metadata.updated_at = updated_at;
|
||||
}
|
||||
if let Some(source) = patch.source {
|
||||
metadata.source = enum_to_string(&source);
|
||||
}
|
||||
if let Some(thread_source) = patch.thread_source {
|
||||
metadata.thread_source = thread_source;
|
||||
}
|
||||
if let Some(agent_nickname) = patch.agent_nickname {
|
||||
metadata.agent_nickname = agent_nickname;
|
||||
}
|
||||
if let Some(agent_role) = patch.agent_role {
|
||||
metadata.agent_role = agent_role;
|
||||
}
|
||||
if let Some(agent_path) = patch.agent_path {
|
||||
metadata.agent_path = agent_path;
|
||||
}
|
||||
if let Some(cwd) = patch.cwd {
|
||||
metadata.cwd = normalize_cwd(cwd);
|
||||
}
|
||||
if let Some(cli_version) = patch.cli_version {
|
||||
metadata.cli_version = cli_version;
|
||||
}
|
||||
if let Some(approval_mode) = patch.approval_mode {
|
||||
metadata.approval_mode = enum_to_string(&approval_mode);
|
||||
}
|
||||
if let Some(sandbox_policy) = patch.sandbox_policy {
|
||||
metadata.sandbox_policy = enum_to_string(&sandbox_policy);
|
||||
}
|
||||
if let Some(token_usage) = patch.token_usage {
|
||||
metadata.tokens_used = token_usage.total_tokens.max(0);
|
||||
}
|
||||
if let Some(first_user_message) = patch.first_user_message {
|
||||
metadata.first_user_message = Some(first_user_message);
|
||||
}
|
||||
if let Some(git_info) = patch.git_info {
|
||||
let existing_git_info = git_info_from_parts(
|
||||
metadata.git_sha.clone(),
|
||||
metadata.git_branch.clone(),
|
||||
metadata.git_origin_url.clone(),
|
||||
);
|
||||
let (sha, branch, origin_url) = resolve_git_info_patch(existing_git_info, git_info);
|
||||
metadata.git_sha = sha;
|
||||
metadata.git_branch = branch;
|
||||
metadata.git_origin_url = origin_url;
|
||||
}
|
||||
state_db
|
||||
.upsert_thread(&metadata)
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to update thread metadata for {thread_id}: {err}"),
|
||||
})?;
|
||||
if let Some(memory_mode) = patch.memory_mode {
|
||||
state_db
|
||||
.set_thread_memory_mode(thread_id, memory_mode_as_str(memory_mode))
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to update memory mode for {thread_id}: {err}"),
|
||||
})?;
|
||||
}
|
||||
if let Some(dynamic_tools) = patch.dynamic_tools {
|
||||
state_db
|
||||
.persist_dynamic_tools(thread_id, Some(dynamic_tools.as_slice()))
|
||||
.await
|
||||
.map_err(|err| ThreadStoreError::Internal {
|
||||
message: format!("failed to update dynamic tools for {thread_id}: {err}"),
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.await
|
||||
} else {
|
||||
Ok(())
|
||||
};
|
||||
match (state_db.is_some(), sqlite_write_result) {
|
||||
(true, Ok(())) => {}
|
||||
(true, Err(err)) => return Err(err),
|
||||
(false, Ok(())) => {}
|
||||
(false, Err(err)) => return Err(err),
|
||||
}
|
||||
|
||||
read_thread::read_thread(
|
||||
store,
|
||||
ReadThreadParams {
|
||||
thread_id,
|
||||
include_archived,
|
||||
include_history: false,
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn needs_rollout_compatibility_update(patch: &ThreadMetadataPatch) -> bool {
|
||||
if patch.name.is_some() {
|
||||
return true;
|
||||
}
|
||||
if patch.memory_mode.is_none() && patch.git_info.is_none() {
|
||||
return false;
|
||||
}
|
||||
!has_observed_metadata_facts(patch)
|
||||
}
|
||||
|
||||
fn has_observed_metadata_facts(patch: &ThreadMetadataPatch) -> bool {
|
||||
patch.rollout_path.is_some()
|
||||
|| patch.preview.is_some()
|
||||
|| patch.title.is_some()
|
||||
|| patch.model_provider.is_some()
|
||||
|| patch.model.is_some()
|
||||
|| patch.reasoning_effort.is_some()
|
||||
|| patch.created_at.is_some()
|
||||
|| patch.source.is_some()
|
||||
|| patch.thread_source.is_some()
|
||||
|| patch.agent_nickname.is_some()
|
||||
|| patch.agent_role.is_some()
|
||||
|| patch.agent_path.is_some()
|
||||
|| patch.cwd.is_some()
|
||||
|| patch.cli_version.is_some()
|
||||
|| patch.approval_mode.is_some()
|
||||
|| patch.sandbox_policy.is_some()
|
||||
|| patch.token_usage.is_some()
|
||||
|| patch.first_user_message.is_some()
|
||||
|| patch.dynamic_tools.is_some()
|
||||
}
|
||||
|
||||
fn enum_to_string<T: serde::Serialize>(value: &T) -> String {
|
||||
match serde_json::to_value(value) {
|
||||
Ok(serde_json::Value::String(value)) => value,
|
||||
Ok(other) => other.to_string(),
|
||||
Err(_) => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_cwd(cwd: PathBuf) -> PathBuf {
|
||||
codex_utils_path::normalize_for_path_comparison(cwd.as_path()).unwrap_or(cwd)
|
||||
}
|
||||
|
||||
async fn apply_thread_git_info(
|
||||
store: &LocalThreadStore,
|
||||
thread_id: ThreadId,
|
||||
@@ -363,10 +589,13 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::GitInfoPatch;
|
||||
use crate::ListThreadsParams;
|
||||
use crate::ResumeThreadParams;
|
||||
use crate::SortDirection;
|
||||
use crate::ThreadEventPersistenceMode;
|
||||
use crate::ThreadMetadataPatch;
|
||||
use crate::ThreadPersistenceMetadata;
|
||||
use crate::ThreadSortKey;
|
||||
use crate::ThreadStore;
|
||||
use crate::local::LocalThreadStore;
|
||||
use crate::local::test_support::test_config;
|
||||
@@ -385,7 +614,7 @@ mod tests {
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some("A sharper name".to_string()),
|
||||
name: Some(Some("A sharper name".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
@@ -855,44 +1084,306 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_thread_metadata_rejects_multi_field_patch_without_partial_write() {
|
||||
async fn update_thread_metadata_applies_combined_explicit_patch() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let store = LocalThreadStore::new(test_config(home.path()), /*state_db*/ None);
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime.clone()));
|
||||
let uuid = Uuid::from_u128(305);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
let path =
|
||||
write_session_file(home.path(), "2025-01-03T15-30-00", uuid).expect("session file");
|
||||
let original = std::fs::read_to_string(&path).expect("read rollout");
|
||||
|
||||
let err = store
|
||||
let thread = store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some("Should not persist".to_string()),
|
||||
name: Some(Some("Combined metadata".to_string())),
|
||||
memory_mode: Some(ThreadMemoryMode::Disabled),
|
||||
git_info: Some(GitInfoPatch {
|
||||
branch: Some(Some("combined".to_string())),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect_err("multi-field patch should fail");
|
||||
.expect("combined patch should apply");
|
||||
|
||||
assert!(matches!(err, ThreadStoreError::InvalidRequest { .. }));
|
||||
assert_eq!(thread.name.as_deref(), Some("Combined metadata"));
|
||||
assert_eq!(
|
||||
std::fs::read_to_string(&path).expect("read rollout"),
|
||||
original
|
||||
thread.git_info.expect("git info").branch.as_deref(),
|
||||
Some("combined")
|
||||
);
|
||||
let appended = last_rollout_item(path.as_path());
|
||||
assert_eq!(appended["type"], "session_meta");
|
||||
assert_eq!(appended["payload"]["memory_mode"], "disabled");
|
||||
assert_eq!(appended["payload"]["git"]["branch"], "combined");
|
||||
let latest_name = codex_rollout::find_thread_name_by_id(home.path(), &thread_id)
|
||||
.await
|
||||
.expect("find thread name");
|
||||
assert_eq!(latest_name, None);
|
||||
assert_eq!(latest_name.as_deref(), Some("Combined metadata"));
|
||||
let memory_mode = runtime
|
||||
.get_thread_memory_mode(thread_id)
|
||||
.await
|
||||
.expect("thread memory mode should be readable");
|
||||
assert_eq!(memory_mode.as_deref(), Some("disabled"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metadata_patch_applies_title_over_existing_name() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime));
|
||||
let uuid = Uuid::from_u128(306);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
write_session_file(home.path(), "2025-01-03T15-45-00", uuid).expect("session file");
|
||||
|
||||
store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some(Some("User chosen name".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect("set explicit name");
|
||||
|
||||
let thread = store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
title: Some("Derived first message".to_string()),
|
||||
preview: Some("Derived first message".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect("apply observed metadata");
|
||||
|
||||
assert_eq!(thread.name.as_deref(), Some("Derived first message"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metadata_patch_applies_latest_preview_and_first_user_message() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime.clone()));
|
||||
let uuid = Uuid::from_u128(313);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
write_session_file(home.path(), "2025-01-03T19-00-00", uuid).expect("session file");
|
||||
|
||||
store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
preview: Some("Original preview".to_string()),
|
||||
first_user_message: Some("Original first message".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect("set observed metadata");
|
||||
|
||||
let thread = store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
preview: Some("Later preview".to_string()),
|
||||
first_user_message: Some("Later first message".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect("apply later observed metadata");
|
||||
|
||||
assert_eq!(thread.preview, "Hello from user");
|
||||
assert_eq!(
|
||||
thread.first_user_message.as_deref(),
|
||||
Some("Hello from user")
|
||||
);
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read")
|
||||
.expect("sqlite metadata");
|
||||
assert_eq!(metadata.preview.as_deref(), Some("Later preview"));
|
||||
assert_eq!(
|
||||
metadata.first_user_message.as_deref(),
|
||||
Some("Later first message")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn observed_metadata_rejects_unknown_thread_without_rollout() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime.clone()));
|
||||
let uuid = Uuid::from_u128(314);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
|
||||
let err = store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
preview: Some("phantom".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect_err("metadata-only update should not create a missing thread");
|
||||
|
||||
assert!(matches!(
|
||||
err,
|
||||
ThreadStoreError::InvalidRequest { message }
|
||||
if message == format!("thread not found: {thread_id}")
|
||||
));
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("sqlite metadata read");
|
||||
assert!(metadata.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_thread_metadata_recreates_missing_archived_sqlite_row_as_archived() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let uuid = Uuid::from_u128(315);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
write_archived_session_file(home.path(), "2025-01-03T19-30-00", uuid)
|
||||
.expect("archived session file");
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime.clone()));
|
||||
|
||||
let thread = store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
preview: Some("Archived missing sqlite row".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: true,
|
||||
})
|
||||
.await
|
||||
.expect("update archived thread without sqlite row");
|
||||
|
||||
assert!(thread.archived_at.is_some());
|
||||
assert!(
|
||||
runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("get metadata")
|
||||
.expect("metadata")
|
||||
.archived_at
|
||||
.is_some()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn observed_metadata_normalizes_cwd_for_list_filters() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let runtime = codex_state::StateRuntime::init(
|
||||
home.path().to_path_buf(),
|
||||
config.default_model_provider_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let store = LocalThreadStore::new(config, Some(runtime.clone()));
|
||||
let uuid = Uuid::from_u128(316);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
write_session_file(home.path(), "2025-01-03T20-00-00", uuid).expect("session file");
|
||||
let workspace = home.path().join("workspace");
|
||||
let child = workspace.join("child");
|
||||
std::fs::create_dir_all(child.as_path()).expect("create workspace");
|
||||
let unnormalized_cwd = child.join("..");
|
||||
let normalized_cwd = codex_utils_path::normalize_for_path_comparison(workspace.as_path())
|
||||
.expect("normalize cwd");
|
||||
|
||||
store
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
cwd: Some(unnormalized_cwd),
|
||||
preview: Some("cwd preview".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: false,
|
||||
})
|
||||
.await
|
||||
.expect("update observed cwd");
|
||||
|
||||
let metadata = runtime
|
||||
.get_thread(thread_id)
|
||||
.await
|
||||
.expect("get metadata")
|
||||
.expect("metadata");
|
||||
assert_eq!(metadata.cwd, normalized_cwd);
|
||||
let page = store
|
||||
.list_threads(ListThreadsParams {
|
||||
page_size: 10,
|
||||
cursor: None,
|
||||
sort_key: ThreadSortKey::UpdatedAt,
|
||||
sort_direction: SortDirection::Desc,
|
||||
allowed_sources: Vec::new(),
|
||||
model_providers: Some(Vec::new()),
|
||||
cwd_filters: Some(vec![workspace]),
|
||||
archived: false,
|
||||
search_term: None,
|
||||
use_state_db_only: true,
|
||||
})
|
||||
.await
|
||||
.expect("list threads by cwd");
|
||||
assert_eq!(
|
||||
page.items
|
||||
.iter()
|
||||
.map(|thread| thread.thread_id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec![thread_id]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_thread_metadata_keeps_archived_thread_archived_in_sqlite() {
|
||||
let home = TempDir::new().expect("temp dir");
|
||||
let config = test_config(home.path());
|
||||
let uuid = Uuid::from_u128(306);
|
||||
let uuid = Uuid::from_u128(307);
|
||||
let thread_id = ThreadId::from_string(&uuid.to_string()).expect("valid thread id");
|
||||
let archived_path = write_archived_session_file(home.path(), "2025-01-03T16-00-00", uuid)
|
||||
.expect("archived session file");
|
||||
@@ -931,7 +1422,7 @@ mod tests {
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some("Archived title".to_string()),
|
||||
name: Some(Some("Archived title".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: true,
|
||||
@@ -996,7 +1487,7 @@ mod tests {
|
||||
.update_thread_metadata(UpdateThreadMetadataParams {
|
||||
thread_id,
|
||||
patch: ThreadMetadataPatch {
|
||||
name: Some("Live archived title".to_string()),
|
||||
name: Some(Some("Live archived title".to_string())),
|
||||
..Default::default()
|
||||
},
|
||||
include_archived: true,
|
||||
|
||||
@@ -33,7 +33,11 @@ pub trait ThreadStore: Any + Send + Sync {
|
||||
/// Reopens an existing thread for live appends.
|
||||
async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()>;
|
||||
|
||||
/// Appends items to a live thread.
|
||||
/// Appends canonical rollout items to a live thread.
|
||||
///
|
||||
/// This is the raw history API. It does not infer metadata from item contents. Callers that
|
||||
/// need metadata updates should call [`ThreadStore::update_thread_metadata`] with explicit
|
||||
/// metadata facts prepared above the store.
|
||||
async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()>;
|
||||
|
||||
/// Materializes the thread if persistence is lazy, then persists all queued items.
|
||||
@@ -86,7 +90,10 @@ pub trait ThreadStore: Any + Send + Sync {
|
||||
})
|
||||
}
|
||||
|
||||
/// Applies a mutable metadata patch and returns the updated thread.
|
||||
/// Applies a literal metadata patch and returns the updated thread.
|
||||
///
|
||||
/// Implementations should apply the supplied fields directly. Policy such as deciding whether
|
||||
/// an append-derived preview should be emitted belongs above the store.
|
||||
async fn update_thread_metadata(
|
||||
&self,
|
||||
params: UpdateThreadMetadataParams,
|
||||
|
||||
@@ -0,0 +1,580 @@
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use chrono::DateTime;
|
||||
use chrono::NaiveDateTime;
|
||||
use chrono::Utc;
|
||||
use codex_git_utils::collect_git_info;
|
||||
use codex_git_utils::get_git_repo_root;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::GitInfo;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::ThreadMemoryMode;
|
||||
use codex_protocol::protocol::USER_MESSAGE_BEGIN;
|
||||
use codex_protocol::protocol::UserMessageEvent;
|
||||
|
||||
use crate::CreateThreadParams;
|
||||
use crate::GitInfoPatch;
|
||||
use crate::ResumeThreadParams;
|
||||
use crate::ThreadMetadataPatch;
|
||||
|
||||
const IMAGE_ONLY_USER_MESSAGE_PLACEHOLDER: &str = "[Image]";
|
||||
#[cfg(not(test))]
|
||||
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_secs(5);
|
||||
#[cfg(test)]
|
||||
const THREAD_UPDATED_AT_TOUCH_INTERVAL: Duration = Duration::from_millis(50);
|
||||
|
||||
/// Live-thread helper that derives metadata updates from canonical rollout items.
|
||||
///
|
||||
/// Stores receive raw history plus explicit metadata patches. This helper keeps append-derived
|
||||
/// metadata observation in the live layer without owning persistence-policy filtering or making
|
||||
/// `append_items` infer metadata inside a `ThreadStore` implementation.
|
||||
pub(crate) struct ThreadMetadataSync {
|
||||
thread_id: ThreadId,
|
||||
cwd_seen: bool,
|
||||
preview_seen: bool,
|
||||
first_user_message_seen: bool,
|
||||
title_seen: bool,
|
||||
pending_update: Option<ThreadMetadataPatch>,
|
||||
pending_update_generation: u64,
|
||||
last_touch_persisted_at: Option<Instant>,
|
||||
defer_create_update_until_history_exists: bool,
|
||||
defer_resume_update_until_append: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct PendingThreadMetadataPatch {
|
||||
pub(crate) patch: ThreadMetadataPatch,
|
||||
generation: u64,
|
||||
}
|
||||
|
||||
impl ThreadMetadataSync {
|
||||
pub(crate) async fn for_create(params: &CreateThreadParams) -> Self {
|
||||
let created_at = Utc::now();
|
||||
let cwd = params.metadata.cwd.clone().unwrap_or_default();
|
||||
let git_info = if get_git_repo_root(cwd.as_path()).is_some() {
|
||||
collect_git_info(cwd.as_path()).await.map(|info| GitInfo {
|
||||
commit_hash: info.commit_hash,
|
||||
branch: info.branch,
|
||||
repository_url: info.repository_url,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let dynamic_tools =
|
||||
(!params.dynamic_tools.is_empty()).then(|| params.dynamic_tools.clone());
|
||||
let update = ThreadMetadataPatch {
|
||||
model_provider: Some(params.metadata.model_provider.clone()),
|
||||
created_at: Some(created_at),
|
||||
updated_at: Some(created_at),
|
||||
source: Some(params.source.clone()),
|
||||
thread_source: Some(params.thread_source),
|
||||
agent_nickname: Some(params.source.get_nickname()),
|
||||
agent_role: Some(params.source.get_agent_role()),
|
||||
agent_path: Some(params.source.get_agent_path().map(Into::into)),
|
||||
cwd: Some(cwd.clone()),
|
||||
cli_version: Some(env!("CARGO_PKG_VERSION").to_string()),
|
||||
git_info: git_info.map(git_info_patch_from_observation),
|
||||
memory_mode: Some(params.metadata.memory_mode),
|
||||
dynamic_tools,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
thread_id: params.thread_id,
|
||||
cwd_seen: !cwd.as_os_str().is_empty(),
|
||||
preview_seen: false,
|
||||
first_user_message_seen: false,
|
||||
title_seen: false,
|
||||
pending_update: Some(update),
|
||||
pending_update_generation: 1,
|
||||
last_touch_persisted_at: None,
|
||||
defer_create_update_until_history_exists: true,
|
||||
defer_resume_update_until_append: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn for_resume(params: &ResumeThreadParams) -> Self {
|
||||
let mut sync = Self {
|
||||
thread_id: params.thread_id,
|
||||
cwd_seen: params
|
||||
.metadata
|
||||
.cwd
|
||||
.as_ref()
|
||||
.is_some_and(|cwd| !cwd.as_os_str().is_empty()),
|
||||
preview_seen: false,
|
||||
first_user_message_seen: false,
|
||||
title_seen: false,
|
||||
pending_update: None,
|
||||
pending_update_generation: 0,
|
||||
last_touch_persisted_at: None,
|
||||
defer_create_update_until_history_exists: false,
|
||||
defer_resume_update_until_append: false,
|
||||
};
|
||||
if let Some(history) = params.history.as_deref() {
|
||||
let update = sync.observe_resume_history(history);
|
||||
sync.merge_pending_update(update);
|
||||
sync.defer_resume_update_until_append = sync.pending_update.is_some();
|
||||
}
|
||||
sync
|
||||
}
|
||||
|
||||
pub(crate) fn take_pending_update(&self) -> Option<PendingThreadMetadataPatch> {
|
||||
self.pending_update
|
||||
.clone()
|
||||
.map(|patch| PendingThreadMetadataPatch {
|
||||
patch,
|
||||
generation: self.pending_update_generation,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn take_pending_update_for_existing_history(
|
||||
&self,
|
||||
) -> Option<PendingThreadMetadataPatch> {
|
||||
if self.defer_create_update_until_history_exists {
|
||||
return None;
|
||||
}
|
||||
if self.defer_resume_update_until_append {
|
||||
return None;
|
||||
}
|
||||
self.take_pending_update()
|
||||
}
|
||||
|
||||
pub(crate) fn mark_pending_update_applied(&mut self, update: &PendingThreadMetadataPatch) {
|
||||
if self.pending_update_generation == update.generation {
|
||||
self.pending_update = None;
|
||||
}
|
||||
if update.patch.updated_at.is_some() {
|
||||
self.last_touch_persisted_at = Some(Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn observe_appended_items(
|
||||
&mut self,
|
||||
items: &[RolloutItem],
|
||||
) -> Option<PendingThreadMetadataPatch> {
|
||||
self.defer_create_update_until_history_exists = false;
|
||||
self.defer_resume_update_until_append = false;
|
||||
let affects_metadata = items
|
||||
.iter()
|
||||
.any(codex_state::rollout_item_affects_thread_metadata);
|
||||
let update = if affects_metadata {
|
||||
self.observe_items(items)?
|
||||
} else {
|
||||
thread_updated_at_touch()
|
||||
};
|
||||
self.merge_pending_update(Some(update));
|
||||
if !affects_metadata
|
||||
&& !self
|
||||
.pending_update
|
||||
.as_ref()
|
||||
.is_some_and(update_has_metadata_facts)
|
||||
&& self.last_touch_persisted_at.is_some_and(|last_touch| {
|
||||
Instant::now().duration_since(last_touch) < THREAD_UPDATED_AT_TOUCH_INTERVAL
|
||||
})
|
||||
{
|
||||
return None;
|
||||
}
|
||||
self.take_pending_update()
|
||||
}
|
||||
|
||||
fn observe_items(&mut self, items: &[RolloutItem]) -> Option<ThreadMetadataPatch> {
|
||||
self.observe_items_with_update(
|
||||
items,
|
||||
ThreadMetadataPatch {
|
||||
updated_at: Some(Utc::now()),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn observe_resume_history(&mut self, items: &[RolloutItem]) -> Option<ThreadMetadataPatch> {
|
||||
self.observe_items_with_update(items, ThreadMetadataPatch::default())
|
||||
}
|
||||
|
||||
fn observe_items_with_update(
|
||||
&mut self,
|
||||
items: &[RolloutItem],
|
||||
mut update: ThreadMetadataPatch,
|
||||
) -> Option<ThreadMetadataPatch> {
|
||||
if items.is_empty() {
|
||||
return None;
|
||||
}
|
||||
for item in items {
|
||||
match item {
|
||||
RolloutItem::SessionMeta(meta_line) if meta_line.meta.id == self.thread_id => {
|
||||
update.created_at = parse_session_timestamp(meta_line.meta.timestamp.as_str());
|
||||
update.source = Some(meta_line.meta.source.clone());
|
||||
update.thread_source = Some(meta_line.meta.thread_source);
|
||||
update.agent_nickname = Some(meta_line.meta.agent_nickname.clone());
|
||||
update.agent_role = Some(meta_line.meta.agent_role.clone());
|
||||
update.agent_path = Some(meta_line.meta.agent_path.clone());
|
||||
if let Some(model_provider) = meta_line.meta.model_provider.clone()
|
||||
&& !model_provider.is_empty()
|
||||
{
|
||||
update.model_provider = Some(model_provider);
|
||||
}
|
||||
if !meta_line.meta.cli_version.is_empty() {
|
||||
update.cli_version = Some(meta_line.meta.cli_version.clone());
|
||||
}
|
||||
if !meta_line.meta.cwd.as_os_str().is_empty() {
|
||||
self.cwd_seen = true;
|
||||
update.cwd = Some(meta_line.meta.cwd.clone());
|
||||
}
|
||||
if let Some(git_info) = meta_line.git.clone() {
|
||||
update.git_info = Some(git_info_patch_from_observation(git_info));
|
||||
}
|
||||
if let Some(memory_mode) = meta_line.meta.memory_mode.as_deref()
|
||||
&& let Some(memory_mode) = parse_memory_mode(memory_mode)
|
||||
{
|
||||
update.memory_mode = Some(memory_mode);
|
||||
}
|
||||
if let Some(dynamic_tools) = meta_line.meta.dynamic_tools.clone() {
|
||||
update.dynamic_tools = Some(dynamic_tools);
|
||||
}
|
||||
}
|
||||
RolloutItem::TurnContext(turn_ctx) => {
|
||||
if !self.cwd_seen && !turn_ctx.cwd.as_os_str().is_empty() {
|
||||
self.cwd_seen = true;
|
||||
update.cwd = Some(turn_ctx.cwd.clone());
|
||||
}
|
||||
update.model = Some(turn_ctx.model.clone());
|
||||
update.reasoning_effort = turn_ctx.effort;
|
||||
update.approval_mode = Some(turn_ctx.approval_policy);
|
||||
update.sandbox_policy = Some(turn_ctx.sandbox_policy.clone());
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(user)) => {
|
||||
if let Some(preview) = user_message_preview(user) {
|
||||
if !self.first_user_message_seen {
|
||||
self.first_user_message_seen = true;
|
||||
update.first_user_message = Some(preview.clone());
|
||||
}
|
||||
if !self.preview_seen {
|
||||
self.preview_seen = true;
|
||||
update.preview = Some(preview);
|
||||
}
|
||||
}
|
||||
if !self.title_seen {
|
||||
let title = strip_user_message_prefix(user.message.as_str());
|
||||
if !title.is_empty() {
|
||||
self.title_seen = true;
|
||||
update.title = Some(title.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::TokenCount(token_count)) => {
|
||||
if let Some(info) = token_count.info.as_ref() {
|
||||
update.token_usage = Some(info.total_token_usage.clone());
|
||||
}
|
||||
}
|
||||
RolloutItem::EventMsg(EventMsg::ThreadGoalUpdated(event)) => {
|
||||
if !self.preview_seen {
|
||||
let objective = event.goal.objective.trim();
|
||||
if !objective.is_empty() {
|
||||
self.preview_seen = true;
|
||||
update.preview = Some(objective.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutItem::SessionMeta(_)
|
||||
| RolloutItem::EventMsg(_)
|
||||
| RolloutItem::ResponseItem(_)
|
||||
| RolloutItem::Compacted(_) => {}
|
||||
}
|
||||
}
|
||||
Some(update)
|
||||
}
|
||||
|
||||
fn merge_pending_update(&mut self, update: Option<ThreadMetadataPatch>) {
|
||||
let Some(update) = update else {
|
||||
return;
|
||||
};
|
||||
match self.pending_update.as_mut() {
|
||||
Some(pending_update) => pending_update.merge(update),
|
||||
None => self.pending_update = Some(update),
|
||||
}
|
||||
self.pending_update_generation = self.pending_update_generation.wrapping_add(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_memory_mode(value: &str) -> Option<ThreadMemoryMode> {
|
||||
match value {
|
||||
"enabled" => Some(ThreadMemoryMode::Enabled),
|
||||
"disabled" => Some(ThreadMemoryMode::Disabled),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_session_timestamp(value: &str) -> Option<DateTime<Utc>> {
|
||||
DateTime::parse_from_rfc3339(value)
|
||||
.map(|timestamp| timestamp.with_timezone(&Utc))
|
||||
.or_else(|_| {
|
||||
NaiveDateTime::parse_from_str(value, "%Y-%m-%dT%H-%M-%S")
|
||||
.map(|timestamp| DateTime::from_naive_utc_and_offset(timestamp, Utc))
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
|
||||
fn strip_user_message_prefix(text: &str) -> &str {
|
||||
match text.find(USER_MESSAGE_BEGIN) {
|
||||
Some(idx) => text[idx + USER_MESSAGE_BEGIN.len()..].trim(),
|
||||
None => text.trim(),
|
||||
}
|
||||
}
|
||||
|
||||
fn user_message_preview(user: &UserMessageEvent) -> Option<String> {
|
||||
let message = strip_user_message_prefix(user.message.as_str());
|
||||
if !message.is_empty() {
|
||||
return Some(message.to_string());
|
||||
}
|
||||
if user
|
||||
.images
|
||||
.as_ref()
|
||||
.is_some_and(|images| !images.is_empty())
|
||||
|| !user.local_images.is_empty()
|
||||
{
|
||||
return Some(IMAGE_ONLY_USER_MESSAGE_PLACEHOLDER.to_string());
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn thread_updated_at_touch() -> ThreadMetadataPatch {
|
||||
ThreadMetadataPatch {
|
||||
updated_at: Some(Utc::now()),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn update_has_metadata_facts(update: &ThreadMetadataPatch) -> bool {
|
||||
update.rollout_path.is_some()
|
||||
|| update.preview.is_some()
|
||||
|| update.title.is_some()
|
||||
|| update.model_provider.is_some()
|
||||
|| update.model.is_some()
|
||||
|| update.reasoning_effort.is_some()
|
||||
|| update.created_at.is_some()
|
||||
|| update.source.is_some()
|
||||
|| update.thread_source.is_some()
|
||||
|| update.agent_nickname.is_some()
|
||||
|| update.agent_role.is_some()
|
||||
|| update.agent_path.is_some()
|
||||
|| update.cwd.is_some()
|
||||
|| update.cli_version.is_some()
|
||||
|| update.approval_mode.is_some()
|
||||
|| update.sandbox_policy.is_some()
|
||||
|| update.token_usage.is_some()
|
||||
|| update.first_user_message.is_some()
|
||||
|| update.git_info.is_some()
|
||||
|| update.memory_mode.is_some()
|
||||
|| update.dynamic_tools.is_some()
|
||||
}
|
||||
|
||||
fn git_info_patch_from_observation(git_info: GitInfo) -> GitInfoPatch {
|
||||
GitInfoPatch {
|
||||
sha: git_info.commit_hash.map(|sha| Some(sha.0)),
|
||||
branch: git_info.branch.map(Some),
|
||||
origin_url: git_info.repository_url.map(Some),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use codex_protocol::protocol::CompactedItem;
|
||||
use codex_protocol::protocol::SessionMeta;
|
||||
use codex_protocol::protocol::SessionMetaLine;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::ThreadGoal;
|
||||
use codex_protocol::protocol::ThreadGoalStatus;
|
||||
use codex_protocol::protocol::ThreadGoalUpdatedEvent;
|
||||
use codex_protocol::protocol::UserMessageEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
use crate::ThreadEventPersistenceMode;
|
||||
use crate::ThreadPersistenceMetadata;
|
||||
|
||||
#[test]
|
||||
fn resume_history_keeps_derived_metadata_pending_until_applied() {
|
||||
let thread_id = ThreadId::new();
|
||||
let mut sync = ThreadMetadataSync::for_resume(&resume_params(
|
||||
thread_id,
|
||||
vec![
|
||||
RolloutItem::SessionMeta(session_meta(thread_id)),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(user_message("hello metadata"))),
|
||||
],
|
||||
));
|
||||
|
||||
let update = sync.take_pending_update().expect("pending metadata update");
|
||||
assert_eq!(
|
||||
update
|
||||
.patch
|
||||
.created_at
|
||||
.expect("created_at should come from session metadata")
|
||||
.to_rfc3339(),
|
||||
"2025-01-03T12:00:00+00:00"
|
||||
);
|
||||
assert_eq!(update.patch.preview.as_deref(), Some("hello metadata"));
|
||||
assert_eq!(update.patch.title.as_deref(), Some("hello metadata"));
|
||||
assert_eq!(
|
||||
update.patch.first_user_message.as_deref(),
|
||||
Some("hello metadata")
|
||||
);
|
||||
assert_eq!(update.patch.updated_at, None);
|
||||
assert!(
|
||||
sync.take_pending_update().is_some(),
|
||||
"taking the pending update should not drop retry state"
|
||||
);
|
||||
|
||||
sync.mark_pending_update_applied(&update);
|
||||
assert!(sync.take_pending_update().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn goal_update_sets_preview_without_overriding_existing_preview() {
|
||||
let thread_id = ThreadId::new();
|
||||
let sync = ThreadMetadataSync::for_resume(&resume_params(
|
||||
thread_id,
|
||||
vec![
|
||||
RolloutItem::EventMsg(EventMsg::ThreadGoalUpdated(goal_update(
|
||||
thread_id,
|
||||
"ship the refactor",
|
||||
))),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(user_message("first user text"))),
|
||||
],
|
||||
));
|
||||
|
||||
let update = sync.take_pending_update().expect("pending metadata update");
|
||||
assert_eq!(update.patch.preview.as_deref(), Some("ship the refactor"));
|
||||
assert_eq!(
|
||||
update.patch.first_user_message.as_deref(),
|
||||
Some("first user text")
|
||||
);
|
||||
assert_eq!(update.patch.title.as_deref(), Some("first user text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn later_user_messages_do_not_emit_existing_preview_fields() {
|
||||
let thread_id = ThreadId::new();
|
||||
let mut sync = ThreadMetadataSync::for_resume(&resume_params(
|
||||
thread_id,
|
||||
vec![RolloutItem::EventMsg(EventMsg::UserMessage(user_message(
|
||||
"first user text",
|
||||
)))],
|
||||
));
|
||||
let pending = sync.take_pending_update().expect("pending resume metadata");
|
||||
sync.mark_pending_update_applied(&pending);
|
||||
|
||||
let update = sync
|
||||
.observe_appended_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(user_message(
|
||||
"later user text",
|
||||
)))])
|
||||
.expect("updated_at touch");
|
||||
|
||||
assert_eq!(update.patch.preview, None);
|
||||
assert_eq!(update.patch.title, None);
|
||||
assert_eq!(update.patch.first_user_message, None);
|
||||
assert!(update.patch.updated_at.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metadata_irrelevant_items_coalesce_updated_at_touches() {
|
||||
let thread_id = ThreadId::new();
|
||||
let mut sync = ThreadMetadataSync::for_resume(&resume_params(thread_id, Vec::new()));
|
||||
let item = RolloutItem::Compacted(CompactedItem {
|
||||
message: "compacted".to_string(),
|
||||
replacement_history: None,
|
||||
});
|
||||
|
||||
let first = sync
|
||||
.observe_appended_items(std::slice::from_ref(&item))
|
||||
.expect("first touch should apply immediately");
|
||||
assert!(first.patch.updated_at.is_some());
|
||||
sync.mark_pending_update_applied(&first);
|
||||
|
||||
assert!(
|
||||
sync.observe_appended_items(std::slice::from_ref(&item))
|
||||
.is_none(),
|
||||
"second touch inside the coalescing window should wait for a barrier"
|
||||
);
|
||||
assert!(
|
||||
sync.take_pending_update().is_some(),
|
||||
"coalesced touches still flush at the next barrier"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_history_waits_for_append_before_flushing_metadata() {
|
||||
let thread_id = ThreadId::new();
|
||||
let mut sync = ThreadMetadataSync::for_resume(&resume_params(
|
||||
thread_id,
|
||||
vec![
|
||||
RolloutItem::SessionMeta(session_meta(thread_id)),
|
||||
RolloutItem::EventMsg(EventMsg::UserMessage(user_message("hello metadata"))),
|
||||
],
|
||||
));
|
||||
|
||||
assert!(
|
||||
sync.take_pending_update_for_existing_history().is_none(),
|
||||
"resume-only metadata should not flush without a new append"
|
||||
);
|
||||
assert!(
|
||||
sync.observe_appended_items(&[RolloutItem::EventMsg(EventMsg::UserMessage(
|
||||
user_message("new append"),
|
||||
))])
|
||||
.is_some(),
|
||||
"the first append should flush resume metadata together with append metadata"
|
||||
);
|
||||
}
|
||||
|
||||
fn resume_params(thread_id: ThreadId, history: Vec<RolloutItem>) -> ResumeThreadParams {
|
||||
ResumeThreadParams {
|
||||
thread_id,
|
||||
rollout_path: None,
|
||||
history: Some(history),
|
||||
include_archived: false,
|
||||
metadata: ThreadPersistenceMetadata {
|
||||
cwd: None,
|
||||
model_provider: "test-provider".to_string(),
|
||||
memory_mode: ThreadMemoryMode::Enabled,
|
||||
},
|
||||
event_persistence_mode: ThreadEventPersistenceMode::Limited,
|
||||
}
|
||||
}
|
||||
|
||||
fn user_message(message: &str) -> UserMessageEvent {
|
||||
UserMessageEvent {
|
||||
message: message.to_string(),
|
||||
images: None,
|
||||
local_images: Vec::new(),
|
||||
text_elements: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn session_meta(thread_id: ThreadId) -> SessionMetaLine {
|
||||
SessionMetaLine {
|
||||
meta: SessionMeta {
|
||||
id: thread_id,
|
||||
timestamp: "2025-01-03T12:00:00Z".to_string(),
|
||||
source: SessionSource::Exec,
|
||||
..Default::default()
|
||||
},
|
||||
git: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn goal_update(thread_id: ThreadId, objective: &str) -> ThreadGoalUpdatedEvent {
|
||||
ThreadGoalUpdatedEvent {
|
||||
thread_id,
|
||||
turn_id: None,
|
||||
goal: ThreadGoal {
|
||||
thread_id,
|
||||
objective: objective.to_string(),
|
||||
status: ThreadGoalStatus::Active,
|
||||
token_budget: None,
|
||||
tokens_used: 0,
|
||||
time_used_seconds: 0,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,32 @@ use codex_protocol::protocol::ThreadMemoryMode as MemoryMode;
|
||||
use codex_protocol::protocol::ThreadSource;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serialize;
|
||||
use serde::Serializer;
|
||||
|
||||
mod optional_option {
|
||||
use super::*;
|
||||
|
||||
pub fn serialize<T, S>(value: &Option<Option<T>>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
T: Serialize,
|
||||
S: Serializer,
|
||||
{
|
||||
match value {
|
||||
Some(value) => value.serialize(serializer),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Option::<T>::deserialize(deserializer).map(Some)
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls how many event variants should be persisted for future replay.
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -348,32 +373,241 @@ pub struct StoredThread {
|
||||
}
|
||||
|
||||
/// Optional field patch where omission leaves a value unchanged and `Some(None)` clears it.
|
||||
pub type OptionalStringPatch = Option<Option<String>>;
|
||||
pub type ClearableField<T> = Option<Option<T>>;
|
||||
|
||||
/// Patch for thread Git metadata.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct GitInfoPatch {
|
||||
/// Replacement commit SHA, clear request, or no-op.
|
||||
pub sha: OptionalStringPatch,
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub sha: ClearableField<String>,
|
||||
/// Replacement branch name, clear request, or no-op.
|
||||
pub branch: OptionalStringPatch,
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub branch: ClearableField<String>,
|
||||
/// Replacement origin URL, clear request, or no-op.
|
||||
pub origin_url: OptionalStringPatch,
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub origin_url: ClearableField<String>,
|
||||
}
|
||||
|
||||
/// Patch for mutable thread metadata.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
impl GitInfoPatch {
|
||||
/// Merges another patch into this one using field-presence semantics.
|
||||
///
|
||||
/// Omitted fields in `next` leave the current patch unchanged. Present fields replace the
|
||||
/// current value, including clear requests like `Some(None)`.
|
||||
pub fn merge(&mut self, next: Self) {
|
||||
if next.sha.is_some() {
|
||||
self.sha = next.sha;
|
||||
}
|
||||
if next.branch.is_some() {
|
||||
self.branch = next.branch;
|
||||
}
|
||||
if next.origin_url.is_some() {
|
||||
self.origin_url = next.origin_url;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Patch for thread metadata.
|
||||
///
|
||||
/// Every field is literal: `None` leaves that field unchanged, while `Some`
|
||||
/// applies the supplied value. Fields whose value may itself be cleared use an
|
||||
/// inner `Option`, where `Some(None)` clears the field.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct ThreadMetadataPatch {
|
||||
/// Replacement user-facing thread name.
|
||||
pub name: Option<String>,
|
||||
/// Replacement thread memory behavior.
|
||||
pub memory_mode: Option<MemoryMode>,
|
||||
/// Optional Git metadata patch.
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub name: ClearableField<String>,
|
||||
/// Known local rollout path for stores that expose one.
|
||||
pub rollout_path: Option<PathBuf>,
|
||||
/// Best available preview text for discovery/listing.
|
||||
pub preview: Option<String>,
|
||||
/// Best-effort title derived from history.
|
||||
pub title: Option<String>,
|
||||
/// Model provider associated with the thread.
|
||||
pub model_provider: Option<String>,
|
||||
/// Latest observed model.
|
||||
pub model: Option<String>,
|
||||
/// Latest observed reasoning effort.
|
||||
pub reasoning_effort: Option<ReasoningEffort>,
|
||||
/// Creation timestamp when known.
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
/// Last update timestamp for this metadata observation.
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
/// Session source.
|
||||
pub source: Option<SessionSource>,
|
||||
/// Optional analytics source classification.
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub thread_source: ClearableField<ThreadSource>,
|
||||
/// Optional agent nickname.
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub agent_nickname: ClearableField<String>,
|
||||
/// Optional agent role.
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub agent_role: ClearableField<String>,
|
||||
/// Optional canonical agent path.
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "optional_option"
|
||||
)]
|
||||
pub agent_path: ClearableField<String>,
|
||||
/// Working directory.
|
||||
pub cwd: Option<PathBuf>,
|
||||
/// CLI version that created the thread.
|
||||
pub cli_version: Option<String>,
|
||||
/// Approval mode.
|
||||
pub approval_mode: Option<AskForApproval>,
|
||||
/// Sandbox policy.
|
||||
pub sandbox_policy: Option<SandboxPolicy>,
|
||||
/// Last observed token usage.
|
||||
pub token_usage: Option<TokenUsage>,
|
||||
/// First user message observed for this thread.
|
||||
pub first_user_message: Option<String>,
|
||||
/// Git metadata patch.
|
||||
pub git_info: Option<GitInfoPatch>,
|
||||
/// Thread memory behavior.
|
||||
pub memory_mode: Option<MemoryMode>,
|
||||
/// Dynamic tools available to this thread.
|
||||
pub dynamic_tools: Option<Vec<DynamicToolSpec>>,
|
||||
}
|
||||
|
||||
impl ThreadMetadataPatch {
|
||||
/// Merges another patch into this one using field-presence semantics.
|
||||
///
|
||||
/// Omitted fields in `next` leave the current patch unchanged. Present fields replace the
|
||||
/// current value, including clear requests like `Some(None)`. Nested patches use the same
|
||||
/// semantics.
|
||||
pub fn merge(&mut self, next: Self) {
|
||||
if next.name.is_some() {
|
||||
self.name = next.name;
|
||||
}
|
||||
if next.rollout_path.is_some() {
|
||||
self.rollout_path = next.rollout_path;
|
||||
}
|
||||
if next.preview.is_some() {
|
||||
self.preview = next.preview;
|
||||
}
|
||||
if next.title.is_some() {
|
||||
self.title = next.title;
|
||||
}
|
||||
if next.model_provider.is_some() {
|
||||
self.model_provider = next.model_provider;
|
||||
}
|
||||
if next.model.is_some() {
|
||||
self.model = next.model;
|
||||
}
|
||||
if next.reasoning_effort.is_some() {
|
||||
self.reasoning_effort = next.reasoning_effort;
|
||||
}
|
||||
if next.created_at.is_some() {
|
||||
self.created_at = next.created_at;
|
||||
}
|
||||
if next.updated_at.is_some() {
|
||||
self.updated_at = next.updated_at;
|
||||
}
|
||||
if next.source.is_some() {
|
||||
self.source = next.source;
|
||||
}
|
||||
if next.thread_source.is_some() {
|
||||
self.thread_source = next.thread_source;
|
||||
}
|
||||
if next.agent_nickname.is_some() {
|
||||
self.agent_nickname = next.agent_nickname;
|
||||
}
|
||||
if next.agent_role.is_some() {
|
||||
self.agent_role = next.agent_role;
|
||||
}
|
||||
if next.agent_path.is_some() {
|
||||
self.agent_path = next.agent_path;
|
||||
}
|
||||
if next.cwd.is_some() {
|
||||
self.cwd = next.cwd;
|
||||
}
|
||||
if next.cli_version.is_some() {
|
||||
self.cli_version = next.cli_version;
|
||||
}
|
||||
if next.approval_mode.is_some() {
|
||||
self.approval_mode = next.approval_mode;
|
||||
}
|
||||
if next.sandbox_policy.is_some() {
|
||||
self.sandbox_policy = next.sandbox_policy;
|
||||
}
|
||||
if next.token_usage.is_some() {
|
||||
self.token_usage = next.token_usage;
|
||||
}
|
||||
if next.first_user_message.is_some() {
|
||||
self.first_user_message = next.first_user_message;
|
||||
}
|
||||
if let Some(git_info) = next.git_info {
|
||||
self.git_info
|
||||
.get_or_insert_with(GitInfoPatch::default)
|
||||
.merge(git_info);
|
||||
}
|
||||
if next.memory_mode.is_some() {
|
||||
self.memory_mode = next.memory_mode;
|
||||
}
|
||||
if next.dynamic_tools.is_some() {
|
||||
self.dynamic_tools = next.dynamic_tools;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.name.is_none()
|
||||
&& self.rollout_path.is_none()
|
||||
&& self.preview.is_none()
|
||||
&& self.title.is_none()
|
||||
&& self.model_provider.is_none()
|
||||
&& self.model.is_none()
|
||||
&& self.reasoning_effort.is_none()
|
||||
&& self.created_at.is_none()
|
||||
&& self.updated_at.is_none()
|
||||
&& self.source.is_none()
|
||||
&& self.thread_source.is_none()
|
||||
&& self.agent_nickname.is_none()
|
||||
&& self.agent_role.is_none()
|
||||
&& self.agent_path.is_none()
|
||||
&& self.cwd.is_none()
|
||||
&& self.cli_version.is_none()
|
||||
&& self.approval_mode.is_none()
|
||||
&& self.sandbox_policy.is_none()
|
||||
&& self.token_usage.is_none()
|
||||
&& self.first_user_message.is_none()
|
||||
&& self.git_info.is_none()
|
||||
&& self.memory_mode.is_none()
|
||||
&& self.dynamic_tools.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters for patching mutable thread metadata.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct UpdateThreadMetadataParams {
|
||||
/// Thread id to update.
|
||||
pub thread_id: ThreadId,
|
||||
@@ -389,3 +623,116 @@ pub struct ArchiveThreadParams {
|
||||
/// Thread id to archive or unarchive.
|
||||
pub thread_id: ThreadId,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn thread_metadata_patch_round_trips_optional_clears() {
|
||||
let patch = ThreadMetadataPatch {
|
||||
name: Some(None),
|
||||
thread_source: Some(None),
|
||||
agent_nickname: Some(None),
|
||||
agent_role: Some(None),
|
||||
agent_path: Some(None),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(&patch).expect("serialize patch");
|
||||
assert_eq!(value["name"], json!(null));
|
||||
assert_eq!(value["thread_source"], json!(null));
|
||||
assert_eq!(value["agent_nickname"], json!(null));
|
||||
assert_eq!(value["agent_role"], json!(null));
|
||||
assert_eq!(value["agent_path"], json!(null));
|
||||
|
||||
let decoded: ThreadMetadataPatch =
|
||||
serde_json::from_value(value).expect("deserialize patch");
|
||||
assert_eq!(decoded.name, Some(None));
|
||||
assert_eq!(decoded.thread_source, Some(None));
|
||||
assert_eq!(decoded.agent_nickname, Some(None));
|
||||
assert_eq!(decoded.agent_role, Some(None));
|
||||
assert_eq!(decoded.agent_path, Some(None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn git_info_patch_round_trips_optional_clears() {
|
||||
let patch = ThreadMetadataPatch {
|
||||
git_info: Some(GitInfoPatch {
|
||||
sha: None,
|
||||
branch: Some(Some("main".to_string())),
|
||||
origin_url: Some(None),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(&patch).expect("serialize patch");
|
||||
assert_eq!(
|
||||
value["git_info"],
|
||||
json!({
|
||||
"branch": "main",
|
||||
"origin_url": null,
|
||||
})
|
||||
);
|
||||
|
||||
let decoded: ThreadMetadataPatch =
|
||||
serde_json::from_value(value).expect("deserialize patch");
|
||||
assert_eq!(
|
||||
decoded.git_info,
|
||||
Some(GitInfoPatch {
|
||||
sha: None,
|
||||
branch: Some(Some("main".to_string())),
|
||||
origin_url: Some(None),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thread_metadata_patch_accepts_missing_fields() {
|
||||
let decoded: ThreadMetadataPatch =
|
||||
serde_json::from_value(json!({})).expect("deserialize legacy patch");
|
||||
|
||||
assert!(decoded.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn thread_metadata_patch_merge_uses_presence_semantics() {
|
||||
let mut current = ThreadMetadataPatch {
|
||||
name: Some(Some("old name".to_string())),
|
||||
preview: Some("old preview".to_string()),
|
||||
git_info: Some(GitInfoPatch {
|
||||
sha: Some(Some("abc123".to_string())),
|
||||
branch: Some(Some("main".to_string())),
|
||||
origin_url: None,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
current.merge(ThreadMetadataPatch {
|
||||
name: Some(None),
|
||||
preview: None,
|
||||
title: Some("new title".to_string()),
|
||||
git_info: Some(GitInfoPatch {
|
||||
sha: None,
|
||||
branch: Some(Some("feature".to_string())),
|
||||
origin_url: Some(None),
|
||||
}),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
assert_eq!(current.name, Some(None));
|
||||
assert_eq!(current.preview.as_deref(), Some("old preview"));
|
||||
assert_eq!(current.title.as_deref(), Some("new title"));
|
||||
assert_eq!(
|
||||
current.git_info,
|
||||
Some(GitInfoPatch {
|
||||
sha: Some(Some("abc123".to_string())),
|
||||
branch: Some(Some("feature".to_string())),
|
||||
origin_url: Some(None),
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user