From 5935a906194a74b6c6d0bc6f92b70d15f7cdeffa Mon Sep 17 00:00:00 2001 From: jif Date: Wed, 17 Jun 2026 15:18:39 +0100 Subject: [PATCH] app-server: keep the model cache warm (#28699) ## Why The app server is long-lived, but its shared model cache otherwise refreshes only when a caller needs it. Once the five-minute cache expires, starting a thread or calling `model/list` can wait for `/models` on the request path. Refresh the cache in the background before it expires so foreground callers normally use fresh local state. ## What changed - Start an app-server worker that refreshes models immediately and then every three minutes using the existing models-manager API. - Hold only a weak reference to the models manager between refreshes, so the worker does not extend its lifetime. - Stop scheduling refreshes when the app-server lifecycle handle is shut down or dropped. A refresh already in progress is allowed to finish. - Adjust affected app-server test fixtures to distinguish the background `/models` probe from the connection they are testing. The existing models-manager cache, refresh strategies, auth handling, ETag behavior, and concurrency semantics are unchanged. ## Testing - `models_refresh_worker::tests::refreshes_immediately_periodically_and_stops_when_dropped` - `suite::v2::remote_control::listen_off_honors_persisted_remote_control_enable` - `suite::v2::attestation::attestation_generate_round_trip_adds_header_to_responses_websocket_handshake` --- codex-rs/app-server/src/lib.rs | 1 + codex-rs/app-server/src/message_processor.rs | 8 ++ .../app-server/src/models_refresh_worker.rs | 65 ++++++++++++++ .../src/models_refresh_worker_tests.rs | 90 +++++++++++++++++++ .../app-server/tests/suite/v2/attestation.rs | 43 ++++----- .../tests/suite/v2/remote_control.rs | 82 ++++++++++------- codex-rs/core/tests/common/responses.rs | 7 +- 7 files changed, 238 insertions(+), 58 deletions(-) create mode 100644 codex-rs/app-server/src/models_refresh_worker.rs create mode 100644 codex-rs/app-server/src/models_refresh_worker_tests.rs diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 34bb39da9..cac2db54e 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -101,6 +101,7 @@ pub mod in_process; mod mcp_refresh; mod message_processor; mod models; +mod models_refresh_worker; mod outgoing_message; mod request_processors; mod request_serialization; diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index c32b92c4e..3c3600008 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -92,6 +92,8 @@ use tokio::time::timeout; use tokio_util::sync::CancellationToken; use tracing::Instrument; +use crate::models_refresh_worker::ModelsRefreshWorker; + const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10); const CONNECTION_RPC_DRAIN_TIMEOUT: Duration = Duration::from_secs(/*secs*/ 30); @@ -182,6 +184,7 @@ impl ExternalAuth for ExternalAuthRefreshBridge { pub(crate) struct MessageProcessor { outgoing: Arc, + models_refresh_worker: ModelsRefreshWorker, skills_watcher: Arc, account_processor: AccountRequestProcessor, apps_processor: AppsRequestProcessor, @@ -371,6 +374,8 @@ impl MessageProcessor { )), ) }); + let models_manager = thread_manager.get_models_manager(); + let models_refresh_worker = crate::models_refresh_worker::spawn(&models_manager); thread_manager .plugins_manager() .set_analytics_events_client(analytics_events_client.clone()); @@ -537,6 +542,7 @@ impl MessageProcessor { Self { outgoing, + models_refresh_worker, skills_watcher, account_processor, apps_processor, @@ -566,6 +572,7 @@ impl MessageProcessor { pub(crate) fn clear_runtime_references(&self) { self.account_processor.clear_external_auth(); self.apps_processor.shutdown(); + self.models_refresh_worker.shutdown(); self.skills_watcher.shutdown(); } @@ -742,6 +749,7 @@ impl MessageProcessor { } pub(crate) async fn drain_background_tasks(&self) { + self.models_refresh_worker.shutdown(); self.thread_processor.drain_background_tasks().await; } diff --git a/codex-rs/app-server/src/models_refresh_worker.rs b/codex-rs/app-server/src/models_refresh_worker.rs new file mode 100644 index 000000000..475c8acab --- /dev/null +++ b/codex-rs/app-server/src/models_refresh_worker.rs @@ -0,0 +1,65 @@ +use std::sync::Arc; +use std::time::Duration; + +use codex_models_manager::manager::RefreshStrategy; +use codex_models_manager::manager::SharedModelsManager; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; + +const MODELS_REFRESH_INTERVAL: Duration = Duration::from_secs(3 * 60); + +#[derive(Debug)] +pub(crate) struct ModelsRefreshWorker { + shutdown: CancellationToken, + _task: JoinHandle<()>, +} + +impl ModelsRefreshWorker { + pub(crate) fn shutdown(&self) { + self.shutdown.cancel(); + } +} + +impl Drop for ModelsRefreshWorker { + fn drop(&mut self) { + self.shutdown(); + } +} + +pub(crate) fn spawn(models_manager: &SharedModelsManager) -> ModelsRefreshWorker { + spawn_with_interval(models_manager, MODELS_REFRESH_INTERVAL) +} + +fn spawn_with_interval( + models_manager: &SharedModelsManager, + refresh_interval: Duration, +) -> ModelsRefreshWorker { + let models_manager = Arc::downgrade(models_manager); + let shutdown = CancellationToken::new(); + let worker_shutdown = shutdown.clone(); + let task = tokio::spawn(async move { + loop { + if worker_shutdown.is_cancelled() { + break; + } + let Some(models_manager) = models_manager.upgrade() else { + break; + }; + models_manager.list_models(RefreshStrategy::Online).await; + drop(models_manager); + + tokio::select! { + _ = worker_shutdown.cancelled() => break, + _ = tokio::time::sleep(refresh_interval) => {} + } + } + }); + ModelsRefreshWorker { + shutdown, + _task: task, + } +} + +#[cfg(test)] +#[path = "models_refresh_worker_tests.rs"] +mod tests; diff --git a/codex-rs/app-server/src/models_refresh_worker_tests.rs b/codex-rs/app-server/src/models_refresh_worker_tests.rs new file mode 100644 index 000000000..4d7ff48d3 --- /dev/null +++ b/codex-rs/app-server/src/models_refresh_worker_tests.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use codex_models_manager::manager::ModelsEndpointClient; +use codex_models_manager::manager::ModelsEndpointFuture; +use codex_models_manager::manager::OpenAiModelsManager; +use codex_models_manager::manager::SharedModelsManager; +use codex_protocol::error::CodexErr; +use codex_protocol::error::Result as CoreResult; +use codex_protocol::openai_models::ModelInfo; +use pretty_assertions::assert_eq; +use tempfile::tempdir; +use tokio::sync::Notify; + +use super::*; + +#[derive(Debug)] +struct TestModelsEndpoint { + fetch_count: AtomicUsize, + fetched: Notify, + release_second_fetch: Notify, +} + +impl TestModelsEndpoint { + fn new() -> Arc { + Arc::new(Self { + fetch_count: AtomicUsize::new(0), + fetched: Notify::new(), + release_second_fetch: Notify::new(), + }) + } + + async fn wait_for_fetch_count(&self, expected: usize) { + tokio::time::timeout(Duration::from_secs(1), async { + while self.fetch_count.load(Ordering::SeqCst) < expected { + self.fetched.notified().await; + } + }) + .await + .unwrap_or_else(|_| panic!("expected {expected} model fetches")); + } +} + +impl ModelsEndpointClient for TestModelsEndpoint { + fn has_command_auth(&self) -> bool { + true + } + + fn uses_codex_backend(&self) -> ModelsEndpointFuture<'_, bool> { + Box::pin(async { false }) + } + + fn list_models<'a>( + &'a self, + _client_version: &'a str, + ) -> ModelsEndpointFuture<'a, CoreResult<(Vec, Option)>> { + Box::pin(async move { + let fetch_index = self.fetch_count.fetch_add(1, Ordering::SeqCst); + self.fetched.notify_one(); + if fetch_index == 0 { + return Err(CodexErr::Io(std::io::Error::other("test failure"))); + } + if fetch_index == 1 { + self.release_second_fetch.notified().await; + } + Ok((Vec::new(), None)) + }) + } +} + +#[tokio::test] +async fn refreshes_immediately_periodically_and_stops_when_dropped() { + let codex_home = tempdir().expect("temp dir"); + let endpoint = TestModelsEndpoint::new(); + let models_manager: SharedModelsManager = Arc::new(OpenAiModelsManager::new( + codex_home.path().to_path_buf(), + endpoint.clone(), + /*auth_manager*/ None, + )); + let worker = spawn_with_interval(&models_manager, Duration::from_millis(10)); + + endpoint.wait_for_fetch_count(/*expected*/ 2).await; + drop(worker); + endpoint.release_second_fetch.notify_one(); + tokio::time::sleep(Duration::from_millis(30)).await; + + assert_eq!(endpoint.fetch_count.load(Ordering::SeqCst), 2); +} diff --git a/codex-rs/app-server/tests/suite/v2/attestation.rs b/codex-rs/app-server/tests/suite/v2/attestation.rs index 567f37397..d00761473 100644 --- a/codex-rs/app-server/tests/suite/v2/attestation.rs +++ b/codex-rs/app-server/tests/suite/v2/attestation.rs @@ -4,6 +4,7 @@ use app_test_support::ChatGptAuthFixture; use app_test_support::TestAppServer; use app_test_support::to_response; use app_test_support::write_chatgpt_auth; +use app_test_support::write_models_cache; use codex_app_server_protocol::AttestationGenerateResponse; use codex_app_server_protocol::ClientInfo; use codex_app_server_protocol::InitializeCapabilities; @@ -36,36 +37,26 @@ async fn attestation_generate_round_trip_adds_header_to_responses_websocket_hand { skip_if_no_network!(Ok(())); - let websocket_server = start_websocket_server_with_headers(vec![ - // App-server refreshes `/models` over HTTP during thread startup. It points at the same - // local test base URL, so let that non-websocket probe consume one connection before the - // websocket handshake under test arrives. - WebSocketConnectionConfig { - requests: Vec::new(), - response_headers: Vec::new(), - accept_delay: None, - close_after_requests: true, - }, - WebSocketConnectionConfig { - requests: vec![ - vec![ - responses::ev_response_created("warm-1"), - responses::ev_completed("warm-1"), - ], - vec![ - responses::ev_response_created("resp-1"), - responses::ev_assistant_message("msg-1", "Done"), - responses::ev_completed("resp-1"), - ], + let websocket_server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig { + requests: vec![ + vec![ + responses::ev_response_created("warm-1"), + responses::ev_completed("warm-1"), ], - response_headers: Vec::new(), - accept_delay: None, - close_after_requests: true, - }, - ]) + vec![ + responses::ev_response_created("resp-1"), + responses::ev_assistant_message("msg-1", "Done"), + responses::ev_completed("resp-1"), + ], + ], + response_headers: Vec::new(), + accept_delay: None, + close_after_requests: true, + }]) .await; let codex_home = TempDir::new()?; + write_models_cache(codex_home.path())?; create_chatgpt_websocket_config( codex_home.path(), &websocket_server.uri().replacen("ws://", "http://", 1), diff --git a/codex-rs/app-server/tests/suite/v2/remote_control.rs b/codex-rs/app-server/tests/suite/v2/remote_control.rs index 43076d164..30f7f73be 100644 --- a/codex-rs/app-server/tests/suite/v2/remote_control.rs +++ b/codex-rs/app-server/tests/suite/v2/remote_control.rs @@ -271,7 +271,15 @@ async fn listen_off_honors_persisted_remote_control_enable() -> Result<()> { .await?; let _app_server = TestAppServer::new_with_args(codex_home.path(), &["--listen", "off"]).await?; - timeout(STARTUP_TIMEOUT, listener.accept()).await??; + let request = timeout(STARTUP_TIMEOUT, read_http_request(&listener)).await??; + assert!( + request + .request_line + .starts_with("GET /backend-api/wham/remote/control/server ") + || request + .request_line + .starts_with("POST /backend-api/wham/remote/control/server/refresh ") + ); Ok(()) } @@ -852,9 +860,15 @@ impl BlockingRemoteControlBackend { { return; } - let Ok(_websocket) = listener.accept().await else { + let Ok(request) = read_http_request(&listener).await else { return; }; + if !request + .request_line + .starts_with("GET /backend-api/wham/remote/control/server ") + { + return; + } std::future::pending::<()>().await; } Err(err) => { @@ -1030,36 +1044,44 @@ async fn read_enroll_request(listener: &TcpListener) -> Result<(String, BufReade } async fn read_http_request(listener: &TcpListener) -> Result { - let (stream, _) = listener.accept().await?; - let mut reader = BufReader::new(stream); - - let mut request_line = String::new(); - reader.read_line(&mut request_line).await?; - let mut content_length = 0; loop { - let mut line = String::new(); - reader.read_line(&mut line).await?; - if line == "\r\n" { - break; - } - if let Some(value) = line - .trim_end() - .strip_prefix("content-length:") - .or_else(|| line.trim_end().strip_prefix("Content-Length:")) - { - content_length = value.trim().parse::()?; - } - } - let mut body = vec![0; content_length]; - if content_length > 0 { - reader.read_exact(&mut body).await?; - } + let (stream, _) = listener.accept().await?; + let mut reader = BufReader::new(stream); - Ok(HttpRequest { - request_line: request_line.trim_end().to_string(), - body: String::from_utf8(body)?, - reader, - }) + let mut request_line = String::new(); + reader.read_line(&mut request_line).await?; + let mut content_length = 0; + loop { + let mut line = String::new(); + reader.read_line(&mut line).await?; + if line == "\r\n" { + break; + } + if let Some(value) = line + .trim_end() + .strip_prefix("content-length:") + .or_else(|| line.trim_end().strip_prefix("Content-Length:")) + { + content_length = value.trim().parse::()?; + } + } + let mut body = vec![0; content_length]; + if content_length > 0 { + reader.read_exact(&mut body).await?; + } + + let request_line = request_line.trim_end().to_string(); + if request_line.starts_with("GET ") && request_line.contains("/v1/models?") { + respond_with_json(reader.into_inner(), serde_json::json!({ "models": [] })).await?; + continue; + } + + return Ok(HttpRequest { + request_line, + body: String::from_utf8(body)?, + reader, + }); + } } async fn respond_with_json(stream: TcpStream, body: serde_json::Value) -> Result<()> { diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 9eb0c6f30..b071535c5 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -1237,9 +1237,11 @@ pub async fn start_websocket_server_with_headers( Ok(value) => value, Err(_) => return, }; + // Ordinary HTTP probes can share this listener with websocket tests. Only a + // successful websocket handshake should consume a scripted connection. let connection = { - let mut pending = connections.lock().unwrap(); - pending.pop_front() + let pending = connections.lock().unwrap(); + pending.front().cloned() }; let Some(connection) = connection else { @@ -1291,6 +1293,7 @@ pub async fn start_websocket_server_with_headers( Ok(ws) => ws, Err(_) => continue, }; + connections.lock().unwrap().pop_front(); let connection_index = { let mut log = requests.lock().unwrap();