Files
codex/codex-rs/thread-store/src/in_memory.rs
T
alexsong-oai 1acb722e8a Support thread-level originator overrides (#29477)
## Why

Work(TPP) threads can be launched from the Desktop app, but if they all
keep the Desktop app's default originator then downstream attribution
cannot distinguish local Work launches from cloud-backed Work launches.
`thread/start.serviceName` already carries that launch signal, while
`SessionMeta.originator` is the durable thread-level value that survives
resume and fork.

This change converts the Desktop Work service names into an effective
originator at thread creation time, persists that originator with the
thread, and keeps using it for later model requests and memory writes.

## What changed

- Map `CODEX_WORK_LOCAL` and `CODEX_WORK_CLOUD` service names to
per-thread originators, while preserving
`CODEX_INTERNAL_ORIGINATOR_OVERRIDE` as the highest-precedence override.
- Persist the effective originator in `SessionMeta.originator`, read it
back on resume/fork, and inherit the parent originator for subagent
spawns when there is no persisted session metadata.
- Handle truncated `SpawnAgentForkMode::LastNTurns` forks by falling
back to the live parent originator when the forked history no longer
includes `SessionMeta`.
- Thread the per-thread originator through Responses headers,
websocket/compaction request paths, thread-store creation, rollout
metadata, and memory stage-one telemetry.

## Verification

- `just test -p codex-core
agent::control::tests::spawn_thread_subagent_inherits_parent_originator_without_fork
agent::control::tests::spawn_thread_subagent_fork_last_n_turns_inherits_parent_originator_without_session_meta
thread_manager::tests::originator_override_precedes_service_name_remapping`
- `just test -p codex-core
agent::control::tests::resume_thread_subagent_restores_stored_metadata_and_effective_multi_agent_mode`
- `just test -p codex-memories-write`
- `just fix -p codex-core -p codex-memories-write`
- `git diff --check`
2026-06-23 17:23:38 -07:00

593 lines
21 KiB
Rust

use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::MutexGuard;
use std::sync::OnceLock;
use chrono::Utc;
use codex_protocol::ThreadId;
use codex_protocol::models::PermissionProfile;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionContextWindow;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::ThreadMemoryMode;
use codex_rollout::persisted_rollout_items;
use crate::AppendThreadItemsParams;
use crate::ArchiveThreadParams;
use crate::CreateThreadParams;
use crate::DeleteThreadParams;
use crate::ListThreadsParams;
use crate::LoadThreadHistoryParams;
use crate::ReadThreadByRolloutPathParams;
use crate::ReadThreadParams;
use crate::ResumeThreadParams;
use crate::StoredThread;
use crate::StoredThreadHistory;
use crate::ThreadMetadataPatch;
use crate::ThreadPage;
use crate::ThreadStore;
use crate::ThreadStoreError;
use crate::ThreadStoreFuture;
use crate::ThreadStoreResult;
use crate::UpdateThreadMetadataParams;
static IN_MEMORY_THREAD_STORES: OnceLock<Mutex<HashMap<String, Arc<InMemoryThreadStore>>>> =
OnceLock::new();
fn stores() -> &'static Mutex<HashMap<String, Arc<InMemoryThreadStore>>> {
IN_MEMORY_THREAD_STORES.get_or_init(|| Mutex::new(HashMap::new()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ListItemsParams;
use crate::ListTurnsParams;
use crate::SortDirection;
use crate::StoredTurnItemsView;
use crate::ThreadPersistenceMetadata;
use crate::ThreadSortKey;
use codex_protocol::models::BaseInstructions;
use codex_protocol::protocol::SessionSource;
#[tokio::test]
async fn default_turn_pagination_methods_return_unsupported() {
let store = InMemoryThreadStore::default();
let thread_id = ThreadId::default();
let turns_err = store
.list_turns(ListTurnsParams {
thread_id,
include_archived: true,
cursor: None,
page_size: 10,
sort_direction: SortDirection::Asc,
items_view: StoredTurnItemsView::Summary,
})
.await
.expect_err("default list_turns should be unsupported");
assert!(matches!(
turns_err,
ThreadStoreError::Unsupported {
operation: "list_turns"
}
));
let items_err = store
.list_items(ListItemsParams {
thread_id,
turn_id: None,
include_archived: true,
cursor: None,
page_size: 10,
sort_direction: SortDirection::Asc,
})
.await
.expect_err("default list_items should be unsupported");
assert!(matches!(
items_err,
ThreadStoreError::Unsupported {
operation: "list_items"
}
));
}
#[tokio::test]
async fn list_threads_filters_by_parent_thread_id() {
let store = InMemoryThreadStore::default();
let parent_thread_id = ThreadId::default();
let child_thread_id =
ThreadId::from_string("00000000-0000-0000-0000-000000000001").expect("valid thread id");
let unrelated_thread_id =
ThreadId::from_string("00000000-0000-0000-0000-000000000002").expect("valid thread id");
for (thread_id, parent_thread_id) in [
(child_thread_id, Some(parent_thread_id)),
(unrelated_thread_id, None),
] {
store
.create_thread(CreateThreadParams {
session_id: thread_id.into(),
thread_id,
extra_config: None,
forked_from_id: None,
parent_thread_id,
source: SessionSource::Exec,
thread_source: None,
originator: "test_originator".to_string(),
base_instructions: BaseInstructions::default(),
dynamic_tools: Vec::new(),
multi_agent_version: None,
initial_window_id: uuid::Uuid::now_v7().to_string(),
metadata: ThreadPersistenceMetadata {
cwd: None,
model_provider: "test-provider".to_string(),
memory_mode: ThreadMemoryMode::Enabled,
},
})
.await
.expect("create thread");
}
let page = ThreadStore::list_threads(
&store,
ListThreadsParams {
page_size: 10,
cursor: None,
sort_key: ThreadSortKey::CreatedAt,
sort_direction: SortDirection::Desc,
allowed_sources: Vec::new(),
model_providers: None,
cwd_filters: None,
archived: false,
search_term: None,
parent_thread_id: Some(parent_thread_id),
use_state_db_only: false,
},
)
.await
.expect("list child threads");
assert_eq!(
page.items
.into_iter()
.map(|item| item.thread_id)
.collect::<Vec<_>>(),
vec![child_thread_id]
);
}
}
fn stores_guard() -> MutexGuard<'static, HashMap<String, Arc<InMemoryThreadStore>>> {
match stores().lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
/// Recorded call counts for [`InMemoryThreadStore`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct InMemoryThreadStoreCalls {
pub create_thread: usize,
pub resume_thread: usize,
pub append_items: usize,
pub persist_thread: usize,
pub flush_thread: usize,
pub shutdown_thread: usize,
pub discard_thread: usize,
pub load_history: usize,
pub read_thread: usize,
pub read_thread_with_history: usize,
pub read_thread_by_rollout_path: usize,
pub list_threads: usize,
pub update_thread_metadata: usize,
pub archive_thread: usize,
pub unarchive_thread: usize,
pub delete_thread: usize,
}
/// In-memory [`ThreadStore`] implementation for tests and debug configs.
///
/// Test and debug configs can select this store by id, letting tests exercise
/// config-driven non-local persistence without requiring the real remote gRPC
/// service.
#[derive(Default)]
pub struct InMemoryThreadStore {
state: tokio::sync::Mutex<InMemoryThreadStoreState>,
}
#[derive(Default)]
struct InMemoryThreadStoreState {
calls: InMemoryThreadStoreCalls,
created_threads: HashMap<ThreadId, CreateThreadParams>,
histories: HashMap<ThreadId, Vec<RolloutItem>>,
metadata_updates: HashMap<ThreadId, ThreadMetadataPatch>,
names: HashMap<ThreadId, Option<String>>,
rollout_paths: HashMap<PathBuf, ThreadId>,
}
impl InMemoryThreadStore {
/// Returns the store associated with `id`, creating it if needed.
pub fn for_id(id: impl Into<String>) -> Arc<Self> {
let id = id.into();
let mut stores = stores_guard();
stores
.entry(id)
.or_insert_with(|| Arc::new(Self::default()))
.clone()
}
/// Removes a shared in-memory store for `id`.
pub fn remove_id(id: &str) -> Option<Arc<Self>> {
stores_guard().remove(id)
}
/// Returns the calls observed by this store.
pub async fn calls(&self) -> InMemoryThreadStoreCalls {
self.state.lock().await.calls.clone()
}
async fn create_thread(&self, params: CreateThreadParams) -> ThreadStoreResult<()> {
let mut state = self.state.lock().await;
state.calls.create_thread += 1;
let session_meta = SessionMeta {
session_id: params.session_id,
id: params.thread_id,
forked_from_id: params.forked_from_id,
parent_thread_id: params.parent_thread_id,
cwd: params.metadata.cwd.clone().unwrap_or_default(),
agent_nickname: params.source.get_nickname(),
agent_role: params.source.get_agent_role(),
agent_path: params.source.get_agent_path().map(Into::into),
originator: params.originator.clone(),
source: params.source.clone(),
thread_source: params.thread_source.clone(),
model_provider: Some(params.metadata.model_provider.clone()),
base_instructions: Some(params.base_instructions.clone()),
dynamic_tools: (!params.dynamic_tools.is_empty()).then(|| params.dynamic_tools.clone()),
memory_mode: matches!(params.metadata.memory_mode, ThreadMemoryMode::Disabled)
.then_some("disabled".to_string()),
multi_agent_version: params.multi_agent_version,
context_window: Some(SessionContextWindow::new(params.initial_window_id.clone())),
..SessionMeta::default()
};
state
.histories
.entry(params.thread_id)
.or_default()
.push(RolloutItem::SessionMeta(SessionMetaLine {
meta: session_meta,
git: None,
}));
state.created_threads.insert(params.thread_id, params);
Ok(())
}
async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()> {
let mut state = self.state.lock().await;
state.calls.resume_thread += 1;
if let Some(history) = params.history {
state
.histories
.insert(params.thread_id, Arc::unwrap_or_clone(history));
} else {
state.histories.entry(params.thread_id).or_default();
}
if let Some(rollout_path) = params.rollout_path {
state.rollout_paths.insert(rollout_path, params.thread_id);
}
Ok(())
}
async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()> {
let canonical_items = persisted_rollout_items(params.items.as_slice());
if canonical_items.is_empty() {
return Ok(());
}
let mut state = self.state.lock().await;
state.calls.append_items += 1;
state
.histories
.entry(params.thread_id)
.or_default()
.extend(canonical_items);
Ok(())
}
async fn load_history(
&self,
params: LoadThreadHistoryParams,
) -> ThreadStoreResult<StoredThreadHistory> {
let mut state = self.state.lock().await;
state.calls.load_history += 1;
let items = state.histories.get(&params.thread_id).cloned().ok_or(
ThreadStoreError::ThreadNotFound {
thread_id: params.thread_id,
},
)?;
Ok(StoredThreadHistory {
thread_id: params.thread_id,
items,
})
}
async fn read_thread(&self, params: ReadThreadParams) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.read_thread += 1;
if params.include_history {
state.calls.read_thread_with_history += 1;
}
stored_thread_from_state(&state, params.thread_id, params.include_history)
}
async fn read_thread_by_rollout_path(
&self,
params: ReadThreadByRolloutPathParams,
) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.read_thread_by_rollout_path += 1;
let Some(thread_id) = state.rollout_paths.get(&params.rollout_path).copied() else {
return Err(ThreadStoreError::InvalidRequest {
message: format!(
"in-memory thread store does not know rollout path {}",
params.rollout_path.display()
),
});
};
stored_thread_from_state(&state, thread_id, params.include_history)
}
async fn list_threads(&self) -> ThreadStoreResult<ThreadPage> {
let mut state = self.state.lock().await;
state.calls.list_threads += 1;
let mut items = state
.created_threads
.keys()
.map(|thread_id| {
stored_thread_from_state(&state, *thread_id, /*include_history*/ false)
})
.collect::<ThreadStoreResult<Vec<_>>>()?;
items.sort_by_key(|item| item.thread_id.to_string());
Ok(ThreadPage {
items,
next_cursor: None,
})
}
async fn update_thread_metadata(
&self,
params: UpdateThreadMetadataParams,
) -> ThreadStoreResult<StoredThread> {
let mut state = self.state.lock().await;
state.calls.update_thread_metadata += 1;
if let Some(name) = params.patch.name.clone() {
state.names.insert(params.thread_id, name);
}
state
.metadata_updates
.entry(params.thread_id)
.or_default()
.merge(params.patch);
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
}
async fn delete_thread(&self, params: DeleteThreadParams) -> ThreadStoreResult<()> {
let mut state = self.state.lock().await;
state.calls.delete_thread += 1;
let existed = state.histories.remove(&params.thread_id).is_some();
state.created_threads.remove(&params.thread_id);
state.names.remove(&params.thread_id);
state.metadata_updates.remove(&params.thread_id);
state
.rollout_paths
.retain(|_, thread_id| *thread_id != params.thread_id);
if existed {
Ok(())
} else {
Err(ThreadStoreError::ThreadNotFound {
thread_id: params.thread_id,
})
}
}
}
impl ThreadStore for InMemoryThreadStore {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn create_thread(&self, params: CreateThreadParams) -> ThreadStoreFuture<'_, ()> {
Box::pin(InMemoryThreadStore::create_thread(self, params))
}
fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreFuture<'_, ()> {
Box::pin(InMemoryThreadStore::resume_thread(self, params))
}
fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreFuture<'_, ()> {
Box::pin(InMemoryThreadStore::append_items(self, params))
}
fn persist_thread(&self, _thread_id: ThreadId) -> ThreadStoreFuture<'_, ()> {
Box::pin(async move {
self.state.lock().await.calls.persist_thread += 1;
Ok(())
})
}
fn flush_thread(&self, _thread_id: ThreadId) -> ThreadStoreFuture<'_, ()> {
Box::pin(async move {
self.state.lock().await.calls.flush_thread += 1;
Ok(())
})
}
fn shutdown_thread(&self, _thread_id: ThreadId) -> ThreadStoreFuture<'_, ()> {
Box::pin(async move {
self.state.lock().await.calls.shutdown_thread += 1;
Ok(())
})
}
fn discard_thread(&self, _thread_id: ThreadId) -> ThreadStoreFuture<'_, ()> {
Box::pin(async move {
self.state.lock().await.calls.discard_thread += 1;
Ok(())
})
}
fn load_history(
&self,
params: LoadThreadHistoryParams,
) -> ThreadStoreFuture<'_, StoredThreadHistory> {
Box::pin(InMemoryThreadStore::load_history(self, params))
}
fn read_thread(&self, params: ReadThreadParams) -> ThreadStoreFuture<'_, StoredThread> {
Box::pin(InMemoryThreadStore::read_thread(self, params))
}
fn read_thread_by_rollout_path(
&self,
params: ReadThreadByRolloutPathParams,
) -> ThreadStoreFuture<'_, StoredThread> {
Box::pin(InMemoryThreadStore::read_thread_by_rollout_path(
self, params,
))
}
fn list_threads(&self, params: ListThreadsParams) -> ThreadStoreFuture<'_, ThreadPage> {
Box::pin(async move {
let mut page = InMemoryThreadStore::list_threads(self).await?;
if let Some(parent_thread_id) = params.parent_thread_id {
page.items
.retain(|thread| thread.parent_thread_id == Some(parent_thread_id));
}
Ok(page)
})
}
fn update_thread_metadata(
&self,
params: UpdateThreadMetadataParams,
) -> ThreadStoreFuture<'_, StoredThread> {
Box::pin(InMemoryThreadStore::update_thread_metadata(self, params))
}
fn archive_thread(&self, _params: ArchiveThreadParams) -> ThreadStoreFuture<'_, ()> {
Box::pin(async move {
self.state.lock().await.calls.archive_thread += 1;
Ok(())
})
}
fn unarchive_thread(&self, params: ArchiveThreadParams) -> ThreadStoreFuture<'_, StoredThread> {
Box::pin(async move {
let mut state = self.state.lock().await;
state.calls.unarchive_thread += 1;
stored_thread_from_state(&state, params.thread_id, /*include_history*/ false)
})
}
fn delete_thread(&self, params: DeleteThreadParams) -> ThreadStoreFuture<'_, ()> {
Box::pin(InMemoryThreadStore::delete_thread(self, params))
}
}
fn stored_thread_from_state(
state: &InMemoryThreadStoreState,
thread_id: ThreadId,
include_history: bool,
) -> ThreadStoreResult<StoredThread> {
let created = state
.created_threads
.get(&thread_id)
.ok_or(ThreadStoreError::ThreadNotFound { thread_id })?;
let history_items = state.histories.get(&thread_id).cloned().unwrap_or_default();
let history = include_history.then(|| StoredThreadHistory {
thread_id,
items: history_items.clone(),
});
let name = state.names.get(&thread_id).cloned().flatten();
let metadata = state.metadata_updates.get(&thread_id);
let rollout_path = state
.rollout_paths
.iter()
.find_map(|(path, mapped_thread_id)| {
(*mapped_thread_id == thread_id).then(|| path.clone())
});
Ok(StoredThread {
thread_id,
extra_config: created.extra_config.clone(),
rollout_path: metadata
.and_then(|metadata| metadata.rollout_path.clone())
.or(rollout_path),
forked_from_id: created.forked_from_id,
parent_thread_id: created.parent_thread_id,
preview: metadata
.and_then(|metadata| metadata.preview.clone())
.unwrap_or_default(),
name,
model_provider: metadata
.and_then(|metadata| metadata.model_provider.clone())
.unwrap_or_else(|| "test".to_string()),
model: metadata.and_then(|metadata| metadata.model.clone()),
reasoning_effort: metadata.and_then(|metadata| metadata.reasoning_effort.clone()),
created_at: metadata
.and_then(|metadata| metadata.created_at)
.unwrap_or_else(Utc::now),
updated_at: metadata
.and_then(|metadata| metadata.updated_at)
.unwrap_or_else(Utc::now),
recency_at: metadata
.and_then(|metadata| metadata.advance_recency_at.or(metadata.updated_at))
.unwrap_or_else(Utc::now),
archived_at: None,
cwd: metadata
.and_then(|metadata| metadata.cwd.clone())
.unwrap_or_default(),
cli_version: metadata
.and_then(|metadata| metadata.cli_version.clone())
.unwrap_or_else(|| "test".to_string()),
source: metadata
.and_then(|metadata| metadata.source.clone())
.unwrap_or_else(|| created.source.clone()),
thread_source: metadata
.and_then(|metadata| metadata.thread_source.clone())
.unwrap_or_else(|| created.thread_source.clone()),
agent_nickname: metadata.and_then(|metadata| metadata.agent_nickname.clone().flatten()),
agent_role: metadata.and_then(|metadata| metadata.agent_role.clone().flatten()),
agent_path: metadata.and_then(|metadata| metadata.agent_path.clone().flatten()),
git_info: metadata.and_then(git_info_from_patch),
approval_mode: metadata
.and_then(|metadata| metadata.approval_mode)
.unwrap_or(AskForApproval::Never),
permission_profile: metadata
.and_then(|metadata| metadata.permission_profile.clone())
.unwrap_or_else(PermissionProfile::read_only),
token_usage: metadata.and_then(|metadata| metadata.token_usage.clone()),
first_user_message: metadata.and_then(|metadata| metadata.first_user_message.clone()),
history,
})
}
fn git_info_from_patch(patch: &ThreadMetadataPatch) -> Option<codex_protocol::protocol::GitInfo> {
let git_info = patch.git_info.as_ref()?;
let sha = git_info.sha.clone().flatten();
let branch = git_info.branch.clone().flatten();
let origin_url = git_info.origin_url.clone().flatten();
if sha.is_none() && branch.is_none() && origin_url.is_none() {
return None;
}
Some(codex_protocol::protocol::GitInfo {
commit_hash: sha.as_deref().map(codex_git_utils::GitSha::new),
branch,
repository_url: origin_url,
})
}