app-server: keep the model cache warm (#28699)

## Why

The app server is long-lived, but its shared model cache otherwise
refreshes only when a caller needs it. Once the five-minute cache
expires, starting a thread or calling `model/list` can wait for
`/models` on the request path.

Refresh the cache in the background before it expires so foreground
callers normally use fresh local state.

## What changed

- Start an app-server worker that refreshes models immediately and then
every three minutes using the existing models-manager API.
- Hold only a weak reference to the models manager between refreshes, so
the worker does not extend its lifetime.
- Stop scheduling refreshes when the app-server lifecycle handle is shut
down or dropped. A refresh already in progress is allowed to finish.
- Adjust affected app-server test fixtures to distinguish the background
`/models` probe from the connection they are testing.

The existing models-manager cache, refresh strategies, auth handling,
ETag behavior, and concurrency semantics are unchanged.

## Testing

- `models_refresh_worker::tests::refreshes_immediately_periodically_and_stops_when_dropped`
- `suite::v2::remote_control::listen_off_honors_persisted_remote_control_enable`
- `suite::v2::attestation::attestation_generate_round_trip_adds_header_to_responses_websocket_handshake`
This commit is contained in:
jif
2026-06-17 15:18:39 +01:00
committed by GitHub
Unverified
parent 45f603302c
commit 5935a90619
7 changed files with 238 additions and 58 deletions
+1
View File
@@ -101,6 +101,7 @@ pub mod in_process;
mod mcp_refresh;
mod message_processor;
mod models;
mod models_refresh_worker;
mod outgoing_message;
mod request_processors;
mod request_serialization;
@@ -92,6 +92,8 @@ use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use crate::models_refresh_worker::ModelsRefreshWorker;
const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
const CONNECTION_RPC_DRAIN_TIMEOUT: Duration = Duration::from_secs(/*secs*/ 30);
@@ -182,6 +184,7 @@ impl ExternalAuth for ExternalAuthRefreshBridge {
pub(crate) struct MessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
models_refresh_worker: ModelsRefreshWorker,
skills_watcher: Arc<SkillsWatcher>,
account_processor: AccountRequestProcessor,
apps_processor: AppsRequestProcessor,
@@ -371,6 +374,8 @@ impl MessageProcessor {
)),
)
});
let models_manager = thread_manager.get_models_manager();
let models_refresh_worker = crate::models_refresh_worker::spawn(&models_manager);
thread_manager
.plugins_manager()
.set_analytics_events_client(analytics_events_client.clone());
@@ -537,6 +542,7 @@ impl MessageProcessor {
Self {
outgoing,
models_refresh_worker,
skills_watcher,
account_processor,
apps_processor,
@@ -566,6 +572,7 @@ impl MessageProcessor {
/// Releases the runtime resources this processor holds.
///
/// In order: clears external auth on the account processor, shuts down the
/// apps processor, signals the models refresh worker to stop, and shuts down
/// the skills watcher.
pub(crate) fn clear_runtime_references(&self) {
    self.account_processor.clear_external_auth();
    self.apps_processor.shutdown();
    // Only signals the worker to stop; this call does not wait for its task.
    self.models_refresh_worker.shutdown();
    self.skills_watcher.shutdown();
}
@@ -742,6 +749,7 @@ impl MessageProcessor {
}
/// Signals the models refresh worker to stop, then awaits the thread
/// processor's own background-task drain.
pub(crate) async fn drain_background_tasks(&self) {
    // The worker's task is not awaited here; only its shutdown is requested.
    self.models_refresh_worker.shutdown();
    self.thread_processor.drain_background_tasks().await;
}
@@ -0,0 +1,65 @@
use std::sync::Arc;
use std::time::Duration;
use codex_models_manager::manager::RefreshStrategy;
use codex_models_manager::manager::SharedModelsManager;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Pause between background model refreshes when the worker is started via
/// `spawn`: three minutes, measured from the end of one refresh to the start
/// of the next.
const MODELS_REFRESH_INTERVAL: Duration = Duration::from_secs(3 * 60);
/// Handle to the background task that periodically refreshes the models list.
///
/// Cancelling the token, either through `shutdown` or by dropping this handle,
/// tells the task to stop scheduling further refreshes.
#[derive(Debug)]
pub(crate) struct ModelsRefreshWorker {
    /// Token the background task watches to know when to stop.
    shutdown: CancellationToken,
    /// Join handle of the spawned task. It is stored but never awaited or
    /// aborted in this module.
    _task: JoinHandle<()>,
}
impl ModelsRefreshWorker {
    /// Requests that the background task stop by cancelling its token.
    ///
    /// This only signals the task; it does not wait for the task to exit.
    pub(crate) fn shutdown(&self) {
        self.shutdown.cancel();
    }
}
/// Dropping the handle requests shutdown, so the background task stops
/// scheduling refreshes once its owner goes away.
impl Drop for ModelsRefreshWorker {
    fn drop(&mut self) {
        self.shutdown();
    }
}
/// Starts the background models refresh worker with the default
/// `MODELS_REFRESH_INTERVAL` and returns the handle that controls it.
pub(crate) fn spawn(models_manager: &SharedModelsManager) -> ModelsRefreshWorker {
    spawn_with_interval(models_manager, MODELS_REFRESH_INTERVAL)
}
/// Spawns the refresh loop with `tokio::spawn` and returns the handle that
/// controls it.
///
/// The task refreshes the models list right away and then again after each
/// `refresh_interval` pause. Between refreshes it holds only a `Weak`
/// reference, so it does not keep the models manager alive. It exits when the
/// manager has been dropped or when the returned worker's token is cancelled.
/// Cancellation is observed before a refresh starts and while pausing; a
/// refresh that is already running is awaited to completion.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because it calls `tokio::spawn`.
fn spawn_with_interval(
    models_manager: &SharedModelsManager,
    refresh_interval: Duration,
) -> ModelsRefreshWorker {
    let shutdown = CancellationToken::new();
    let task = tokio::spawn({
        let weak_manager = Arc::downgrade(models_manager);
        let cancel = shutdown.clone();
        async move {
            while !cancel.is_cancelled() {
                {
                    // The strong reference lives only for the duration of one
                    // refresh and is released at the end of this scope.
                    let Some(manager) = weak_manager.upgrade() else {
                        break;
                    };
                    manager.list_models(RefreshStrategy::Online).await;
                }
                // Pause until the next refresh is due, waking early on shutdown.
                tokio::select! {
                    _ = cancel.cancelled() => break,
                    _ = tokio::time::sleep(refresh_interval) => {}
                }
            }
        }
    });
    ModelsRefreshWorker {
        shutdown,
        _task: task,
    }
}
#[cfg(test)]
#[path = "models_refresh_worker_tests.rs"]
mod tests;
@@ -0,0 +1,90 @@
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use codex_models_manager::manager::ModelsEndpointClient;
use codex_models_manager::manager::ModelsEndpointFuture;
use codex_models_manager::manager::OpenAiModelsManager;
use codex_models_manager::manager::SharedModelsManager;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result as CoreResult;
use codex_protocol::openai_models::ModelInfo;
use pretty_assertions::assert_eq;
use tempfile::tempdir;
use tokio::sync::Notify;
use super::*;
/// Test double for the models endpoint that counts fetches and lets the test
/// hold the second fetch open.
#[derive(Debug)]
struct TestModelsEndpoint {
    /// Number of `list_models` calls started so far.
    fetch_count: AtomicUsize,
    /// Notified each time a fetch starts.
    fetched: Notify,
    /// The second fetch waits on this before returning.
    release_second_fetch: Notify,
}
impl TestModelsEndpoint {
    /// Builds an endpoint with no recorded fetches, wrapped in an `Arc`.
    fn new() -> Arc<Self> {
        let endpoint = Self {
            fetch_count: AtomicUsize::new(0),
            fetched: Notify::new(),
            release_second_fetch: Notify::new(),
        };
        Arc::new(endpoint)
    }

    /// Waits until at least `expected` fetches have started.
    ///
    /// # Panics
    ///
    /// Panics if the count is not reached within one second.
    async fn wait_for_fetch_count(&self, expected: usize) {
        let reached = async {
            loop {
                if self.fetch_count.load(Ordering::SeqCst) >= expected {
                    break;
                }
                // Re-check the counter after every fetch notification.
                self.fetched.notified().await;
            }
        };
        let outcome = tokio::time::timeout(Duration::from_secs(1), reached).await;
        if outcome.is_err() {
            panic!("expected {expected} model fetches");
        }
    }
}
impl ModelsEndpointClient for TestModelsEndpoint {
fn has_command_auth(&self) -> bool {
true
}
fn uses_codex_backend(&self) -> ModelsEndpointFuture<'_, bool> {
Box::pin(async { false })
}
fn list_models<'a>(
&'a self,
_client_version: &'a str,
) -> ModelsEndpointFuture<'a, CoreResult<(Vec<ModelInfo>, Option<String>)>> {
Box::pin(async move {
let fetch_index = self.fetch_count.fetch_add(1, Ordering::SeqCst);
self.fetched.notify_one();
if fetch_index == 0 {
return Err(CodexErr::Io(std::io::Error::other("test failure")));
}
if fetch_index == 1 {
self.release_second_fetch.notified().await;
}
Ok((Vec::new(), None))
})
}
}
/// Checks that the worker fetches right after spawning, fetches again after
/// the interval, and schedules no further fetch once the handle is dropped,
/// even though the in-flight second fetch is allowed to finish.
#[tokio::test]
async fn refreshes_immediately_periodically_and_stops_when_dropped() {
    let codex_home = tempdir().expect("temp dir");
    let endpoint = TestModelsEndpoint::new();
    let manager = OpenAiModelsManager::new(
        codex_home.path().to_path_buf(),
        endpoint.clone(),
        /*auth_manager*/ None,
    );
    let models_manager: SharedModelsManager = Arc::new(manager);

    let refresh_interval = Duration::from_millis(10);
    let worker = spawn_with_interval(&models_manager, refresh_interval);
    endpoint.wait_for_fetch_count(/*expected*/ 2).await;

    // Drop the handle while the second fetch is still blocked, then let that
    // fetch complete.
    drop(worker);
    endpoint.release_second_fetch.notify_one();

    // Wait several intervals so a third fetch would show up if one were scheduled.
    tokio::time::sleep(Duration::from_millis(30)).await;
    let fetches_after_drop = endpoint.fetch_count.load(Ordering::SeqCst);
    assert_eq!(fetches_after_drop, 2);
}
@@ -4,6 +4,7 @@ use app_test_support::ChatGptAuthFixture;
use app_test_support::TestAppServer;
use app_test_support::to_response;
use app_test_support::write_chatgpt_auth;
use app_test_support::write_models_cache;
use codex_app_server_protocol::AttestationGenerateResponse;
use codex_app_server_protocol::ClientInfo;
use codex_app_server_protocol::InitializeCapabilities;
@@ -36,36 +37,26 @@ async fn attestation_generate_round_trip_adds_header_to_responses_websocket_hand
{
skip_if_no_network!(Ok(()));
let websocket_server = start_websocket_server_with_headers(vec![
// App-server refreshes `/models` over HTTP during thread startup. It points at the same
// local test base URL, so let that non-websocket probe consume one connection before the
// websocket handshake under test arrives.
WebSocketConnectionConfig {
requests: Vec::new(),
response_headers: Vec::new(),
accept_delay: None,
close_after_requests: true,
},
WebSocketConnectionConfig {
requests: vec![
vec![
responses::ev_response_created("warm-1"),
responses::ev_completed("warm-1"),
],
vec![
responses::ev_response_created("resp-1"),
responses::ev_assistant_message("msg-1", "Done"),
responses::ev_completed("resp-1"),
],
let websocket_server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
requests: vec![
vec![
responses::ev_response_created("warm-1"),
responses::ev_completed("warm-1"),
],
response_headers: Vec::new(),
accept_delay: None,
close_after_requests: true,
},
])
vec![
responses::ev_response_created("resp-1"),
responses::ev_assistant_message("msg-1", "Done"),
responses::ev_completed("resp-1"),
],
],
response_headers: Vec::new(),
accept_delay: None,
close_after_requests: true,
}])
.await;
let codex_home = TempDir::new()?;
write_models_cache(codex_home.path())?;
create_chatgpt_websocket_config(
codex_home.path(),
&websocket_server.uri().replacen("ws://", "http://", 1),
@@ -271,7 +271,15 @@ async fn listen_off_honors_persisted_remote_control_enable() -> Result<()> {
.await?;
let _app_server = TestAppServer::new_with_args(codex_home.path(), &["--listen", "off"]).await?;
timeout(STARTUP_TIMEOUT, listener.accept()).await??;
let request = timeout(STARTUP_TIMEOUT, read_http_request(&listener)).await??;
assert!(
request
.request_line
.starts_with("GET /backend-api/wham/remote/control/server ")
|| request
.request_line
.starts_with("POST /backend-api/wham/remote/control/server/refresh ")
);
Ok(())
}
@@ -852,9 +860,15 @@ impl BlockingRemoteControlBackend {
{
return;
}
let Ok(_websocket) = listener.accept().await else {
let Ok(request) = read_http_request(&listener).await else {
return;
};
if !request
.request_line
.starts_with("GET /backend-api/wham/remote/control/server ")
{
return;
}
std::future::pending::<()>().await;
}
Err(err) => {
@@ -1030,36 +1044,44 @@ async fn read_enroll_request(listener: &TcpListener) -> Result<(String, BufReade
}
async fn read_http_request(listener: &TcpListener) -> Result<HttpRequest> {
let (stream, _) = listener.accept().await?;
let mut reader = BufReader::new(stream);
let mut request_line = String::new();
reader.read_line(&mut request_line).await?;
let mut content_length = 0;
loop {
let mut line = String::new();
reader.read_line(&mut line).await?;
if line == "\r\n" {
break;
}
if let Some(value) = line
.trim_end()
.strip_prefix("content-length:")
.or_else(|| line.trim_end().strip_prefix("Content-Length:"))
{
content_length = value.trim().parse::<usize>()?;
}
}
let mut body = vec![0; content_length];
if content_length > 0 {
reader.read_exact(&mut body).await?;
}
let (stream, _) = listener.accept().await?;
let mut reader = BufReader::new(stream);
Ok(HttpRequest {
request_line: request_line.trim_end().to_string(),
body: String::from_utf8(body)?,
reader,
})
let mut request_line = String::new();
reader.read_line(&mut request_line).await?;
let mut content_length = 0;
loop {
let mut line = String::new();
reader.read_line(&mut line).await?;
if line == "\r\n" {
break;
}
if let Some(value) = line
.trim_end()
.strip_prefix("content-length:")
.or_else(|| line.trim_end().strip_prefix("Content-Length:"))
{
content_length = value.trim().parse::<usize>()?;
}
}
let mut body = vec![0; content_length];
if content_length > 0 {
reader.read_exact(&mut body).await?;
}
let request_line = request_line.trim_end().to_string();
if request_line.starts_with("GET ") && request_line.contains("/v1/models?") {
respond_with_json(reader.into_inner(), serde_json::json!({ "models": [] })).await?;
continue;
}
return Ok(HttpRequest {
request_line,
body: String::from_utf8(body)?,
reader,
});
}
}
async fn respond_with_json(stream: TcpStream, body: serde_json::Value) -> Result<()> {
+5 -2
View File
@@ -1237,9 +1237,11 @@ pub async fn start_websocket_server_with_headers(
Ok(value) => value,
Err(_) => return,
};
// Ordinary HTTP probes can share this listener with websocket tests. Only a
// successful websocket handshake should consume a scripted connection.
let connection = {
let mut pending = connections.lock().unwrap();
pending.pop_front()
let pending = connections.lock().unwrap();
pending.front().cloned()
};
let Some(connection) = connection else {
@@ -1291,6 +1293,7 @@ pub async fn start_websocket_server_with_headers(
Ok(ws) => ws,
Err(_) => continue,
};
connections.lock().unwrap().pop_front();
let connection_index = {
let mut log = requests.lock().unwrap();