mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
app-server: keep the model cache warm (#28699)
## Why

The app server is long-lived, but its shared model cache otherwise refreshes only when a caller needs it. Once the five-minute cache expires, starting a thread or calling `model/list` can wait for `/models` on the request path. Refresh the cache in the background before it expires so foreground callers normally use fresh local state.

## What changed

- Start an app-server worker that refreshes models immediately and then every three minutes using the existing models-manager API.
- Hold only a weak reference to the models manager between refreshes, so the worker does not extend its lifetime.
- Stop scheduling refreshes when the app-server lifecycle handle is shut down or dropped. A refresh already in progress is allowed to finish.
- Adjust affected app-server test fixtures to distinguish the background `/models` probe from the connection they are testing.

The existing models-manager cache, refresh strategies, auth handling, ETag behavior, and concurrency semantics are unchanged.

## Testing

- `models_refresh_worker::tests::refreshes_immediately_periodically_and_stops_when_dropped`
- `suite::v2::remote_control::listen_off_honors_persisted_remote_control_enable`
- `suite::v2::attestation::attestation_generate_round_trip_adds_header_to_responses_websocket_handshake`
This commit is contained in:
@@ -101,6 +101,7 @@ pub mod in_process;
|
||||
mod mcp_refresh;
|
||||
mod message_processor;
|
||||
mod models;
|
||||
mod models_refresh_worker;
|
||||
mod outgoing_message;
|
||||
mod request_processors;
|
||||
mod request_serialization;
|
||||
|
||||
@@ -92,6 +92,8 @@ use tokio::time::timeout;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
|
||||
use crate::models_refresh_worker::ModelsRefreshWorker;
|
||||
|
||||
const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const CONNECTION_RPC_DRAIN_TIMEOUT: Duration = Duration::from_secs(/*secs*/ 30);
|
||||
|
||||
@@ -182,6 +184,7 @@ impl ExternalAuth for ExternalAuthRefreshBridge {
|
||||
|
||||
pub(crate) struct MessageProcessor {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
models_refresh_worker: ModelsRefreshWorker,
|
||||
skills_watcher: Arc<SkillsWatcher>,
|
||||
account_processor: AccountRequestProcessor,
|
||||
apps_processor: AppsRequestProcessor,
|
||||
@@ -371,6 +374,8 @@ impl MessageProcessor {
|
||||
)),
|
||||
)
|
||||
});
|
||||
let models_manager = thread_manager.get_models_manager();
|
||||
let models_refresh_worker = crate::models_refresh_worker::spawn(&models_manager);
|
||||
thread_manager
|
||||
.plugins_manager()
|
||||
.set_analytics_events_client(analytics_events_client.clone());
|
||||
@@ -537,6 +542,7 @@ impl MessageProcessor {
|
||||
|
||||
Self {
|
||||
outgoing,
|
||||
models_refresh_worker,
|
||||
skills_watcher,
|
||||
account_processor,
|
||||
apps_processor,
|
||||
@@ -566,6 +572,7 @@ impl MessageProcessor {
|
||||
pub(crate) fn clear_runtime_references(&self) {
|
||||
self.account_processor.clear_external_auth();
|
||||
self.apps_processor.shutdown();
|
||||
self.models_refresh_worker.shutdown();
|
||||
self.skills_watcher.shutdown();
|
||||
}
|
||||
|
||||
@@ -742,6 +749,7 @@ impl MessageProcessor {
|
||||
}
|
||||
|
||||
pub(crate) async fn drain_background_tasks(&self) {
|
||||
self.models_refresh_worker.shutdown();
|
||||
self.thread_processor.drain_background_tasks().await;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_models_manager::manager::RefreshStrategy;
|
||||
use codex_models_manager::manager::SharedModelsManager;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Pause between the end of one background models refresh and the start of the next
/// (three minutes); this is the interval `spawn` hands to the worker loop.
const MODELS_REFRESH_INTERVAL: Duration = Duration::from_secs(3 * 60);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ModelsRefreshWorker {
|
||||
shutdown: CancellationToken,
|
||||
_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl ModelsRefreshWorker {
|
||||
pub(crate) fn shutdown(&self) {
|
||||
self.shutdown.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ModelsRefreshWorker {
|
||||
fn drop(&mut self) {
|
||||
self.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn spawn(models_manager: &SharedModelsManager) -> ModelsRefreshWorker {
|
||||
spawn_with_interval(models_manager, MODELS_REFRESH_INTERVAL)
|
||||
}
|
||||
|
||||
fn spawn_with_interval(
|
||||
models_manager: &SharedModelsManager,
|
||||
refresh_interval: Duration,
|
||||
) -> ModelsRefreshWorker {
|
||||
let models_manager = Arc::downgrade(models_manager);
|
||||
let shutdown = CancellationToken::new();
|
||||
let worker_shutdown = shutdown.clone();
|
||||
let task = tokio::spawn(async move {
|
||||
loop {
|
||||
if worker_shutdown.is_cancelled() {
|
||||
break;
|
||||
}
|
||||
let Some(models_manager) = models_manager.upgrade() else {
|
||||
break;
|
||||
};
|
||||
models_manager.list_models(RefreshStrategy::Online).await;
|
||||
drop(models_manager);
|
||||
|
||||
tokio::select! {
|
||||
_ = worker_shutdown.cancelled() => break,
|
||||
_ = tokio::time::sleep(refresh_interval) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
ModelsRefreshWorker {
|
||||
shutdown,
|
||||
_task: task,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "models_refresh_worker_tests.rs"]
|
||||
mod tests;
|
||||
@@ -0,0 +1,90 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_models_manager::manager::ModelsEndpointClient;
|
||||
use codex_models_manager::manager::ModelsEndpointFuture;
|
||||
use codex_models_manager::manager::OpenAiModelsManager;
|
||||
use codex_models_manager::manager::SharedModelsManager;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result as CoreResult;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::tempdir;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestModelsEndpoint {
|
||||
fetch_count: AtomicUsize,
|
||||
fetched: Notify,
|
||||
release_second_fetch: Notify,
|
||||
}
|
||||
|
||||
impl TestModelsEndpoint {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
fetch_count: AtomicUsize::new(0),
|
||||
fetched: Notify::new(),
|
||||
release_second_fetch: Notify::new(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn wait_for_fetch_count(&self, expected: usize) {
|
||||
tokio::time::timeout(Duration::from_secs(1), async {
|
||||
while self.fetch_count.load(Ordering::SeqCst) < expected {
|
||||
self.fetched.notified().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("expected {expected} model fetches"));
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelsEndpointClient for TestModelsEndpoint {
|
||||
fn has_command_auth(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn uses_codex_backend(&self) -> ModelsEndpointFuture<'_, bool> {
|
||||
Box::pin(async { false })
|
||||
}
|
||||
|
||||
fn list_models<'a>(
|
||||
&'a self,
|
||||
_client_version: &'a str,
|
||||
) -> ModelsEndpointFuture<'a, CoreResult<(Vec<ModelInfo>, Option<String>)>> {
|
||||
Box::pin(async move {
|
||||
let fetch_index = self.fetch_count.fetch_add(1, Ordering::SeqCst);
|
||||
self.fetched.notify_one();
|
||||
if fetch_index == 0 {
|
||||
return Err(CodexErr::Io(std::io::Error::other("test failure")));
|
||||
}
|
||||
if fetch_index == 1 {
|
||||
self.release_second_fetch.notified().await;
|
||||
}
|
||||
Ok((Vec::new(), None))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refreshes_immediately_periodically_and_stops_when_dropped() {
|
||||
let codex_home = tempdir().expect("temp dir");
|
||||
let endpoint = TestModelsEndpoint::new();
|
||||
let models_manager: SharedModelsManager = Arc::new(OpenAiModelsManager::new(
|
||||
codex_home.path().to_path_buf(),
|
||||
endpoint.clone(),
|
||||
/*auth_manager*/ None,
|
||||
));
|
||||
let worker = spawn_with_interval(&models_manager, Duration::from_millis(10));
|
||||
|
||||
endpoint.wait_for_fetch_count(/*expected*/ 2).await;
|
||||
drop(worker);
|
||||
endpoint.release_second_fetch.notify_one();
|
||||
tokio::time::sleep(Duration::from_millis(30)).await;
|
||||
|
||||
assert_eq!(endpoint.fetch_count.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
@@ -4,6 +4,7 @@ use app_test_support::ChatGptAuthFixture;
|
||||
use app_test_support::TestAppServer;
|
||||
use app_test_support::to_response;
|
||||
use app_test_support::write_chatgpt_auth;
|
||||
use app_test_support::write_models_cache;
|
||||
use codex_app_server_protocol::AttestationGenerateResponse;
|
||||
use codex_app_server_protocol::ClientInfo;
|
||||
use codex_app_server_protocol::InitializeCapabilities;
|
||||
@@ -36,36 +37,26 @@ async fn attestation_generate_round_trip_adds_header_to_responses_websocket_hand
|
||||
{
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let websocket_server = start_websocket_server_with_headers(vec![
|
||||
// App-server refreshes `/models` over HTTP during thread startup. It points at the same
|
||||
// local test base URL, so let that non-websocket probe consume one connection before the
|
||||
// websocket handshake under test arrives.
|
||||
WebSocketConnectionConfig {
|
||||
requests: Vec::new(),
|
||||
response_headers: Vec::new(),
|
||||
accept_delay: None,
|
||||
close_after_requests: true,
|
||||
},
|
||||
WebSocketConnectionConfig {
|
||||
requests: vec![
|
||||
vec![
|
||||
responses::ev_response_created("warm-1"),
|
||||
responses::ev_completed("warm-1"),
|
||||
],
|
||||
vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_assistant_message("msg-1", "Done"),
|
||||
responses::ev_completed("resp-1"),
|
||||
],
|
||||
let websocket_server = start_websocket_server_with_headers(vec![WebSocketConnectionConfig {
|
||||
requests: vec![
|
||||
vec![
|
||||
responses::ev_response_created("warm-1"),
|
||||
responses::ev_completed("warm-1"),
|
||||
],
|
||||
response_headers: Vec::new(),
|
||||
accept_delay: None,
|
||||
close_after_requests: true,
|
||||
},
|
||||
])
|
||||
vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
responses::ev_assistant_message("msg-1", "Done"),
|
||||
responses::ev_completed("resp-1"),
|
||||
],
|
||||
],
|
||||
response_headers: Vec::new(),
|
||||
accept_delay: None,
|
||||
close_after_requests: true,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let codex_home = TempDir::new()?;
|
||||
write_models_cache(codex_home.path())?;
|
||||
create_chatgpt_websocket_config(
|
||||
codex_home.path(),
|
||||
&websocket_server.uri().replacen("ws://", "http://", 1),
|
||||
|
||||
@@ -271,7 +271,15 @@ async fn listen_off_honors_persisted_remote_control_enable() -> Result<()> {
|
||||
.await?;
|
||||
|
||||
let _app_server = TestAppServer::new_with_args(codex_home.path(), &["--listen", "off"]).await?;
|
||||
timeout(STARTUP_TIMEOUT, listener.accept()).await??;
|
||||
let request = timeout(STARTUP_TIMEOUT, read_http_request(&listener)).await??;
|
||||
assert!(
|
||||
request
|
||||
.request_line
|
||||
.starts_with("GET /backend-api/wham/remote/control/server ")
|
||||
|| request
|
||||
.request_line
|
||||
.starts_with("POST /backend-api/wham/remote/control/server/refresh ")
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -852,9 +860,15 @@ impl BlockingRemoteControlBackend {
|
||||
{
|
||||
return;
|
||||
}
|
||||
let Ok(_websocket) = listener.accept().await else {
|
||||
let Ok(request) = read_http_request(&listener).await else {
|
||||
return;
|
||||
};
|
||||
if !request
|
||||
.request_line
|
||||
.starts_with("GET /backend-api/wham/remote/control/server ")
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::future::pending::<()>().await;
|
||||
}
|
||||
Err(err) => {
|
||||
@@ -1030,36 +1044,44 @@ async fn read_enroll_request(listener: &TcpListener) -> Result<(String, BufReade
|
||||
}
|
||||
|
||||
async fn read_http_request(listener: &TcpListener) -> Result<HttpRequest> {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
let mut request_line = String::new();
|
||||
reader.read_line(&mut request_line).await?;
|
||||
let mut content_length = 0;
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader.read_line(&mut line).await?;
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
if let Some(value) = line
|
||||
.trim_end()
|
||||
.strip_prefix("content-length:")
|
||||
.or_else(|| line.trim_end().strip_prefix("Content-Length:"))
|
||||
{
|
||||
content_length = value.trim().parse::<usize>()?;
|
||||
}
|
||||
}
|
||||
let mut body = vec![0; content_length];
|
||||
if content_length > 0 {
|
||||
reader.read_exact(&mut body).await?;
|
||||
}
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let mut reader = BufReader::new(stream);
|
||||
|
||||
Ok(HttpRequest {
|
||||
request_line: request_line.trim_end().to_string(),
|
||||
body: String::from_utf8(body)?,
|
||||
reader,
|
||||
})
|
||||
let mut request_line = String::new();
|
||||
reader.read_line(&mut request_line).await?;
|
||||
let mut content_length = 0;
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
reader.read_line(&mut line).await?;
|
||||
if line == "\r\n" {
|
||||
break;
|
||||
}
|
||||
if let Some(value) = line
|
||||
.trim_end()
|
||||
.strip_prefix("content-length:")
|
||||
.or_else(|| line.trim_end().strip_prefix("Content-Length:"))
|
||||
{
|
||||
content_length = value.trim().parse::<usize>()?;
|
||||
}
|
||||
}
|
||||
let mut body = vec![0; content_length];
|
||||
if content_length > 0 {
|
||||
reader.read_exact(&mut body).await?;
|
||||
}
|
||||
|
||||
let request_line = request_line.trim_end().to_string();
|
||||
if request_line.starts_with("GET ") && request_line.contains("/v1/models?") {
|
||||
respond_with_json(reader.into_inner(), serde_json::json!({ "models": [] })).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok(HttpRequest {
|
||||
request_line,
|
||||
body: String::from_utf8(body)?,
|
||||
reader,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn respond_with_json(stream: TcpStream, body: serde_json::Value) -> Result<()> {
|
||||
|
||||
@@ -1237,9 +1237,11 @@ pub async fn start_websocket_server_with_headers(
|
||||
Ok(value) => value,
|
||||
Err(_) => return,
|
||||
};
|
||||
// Ordinary HTTP probes can share this listener with websocket tests. Only a
|
||||
// successful websocket handshake should consume a scripted connection.
|
||||
let connection = {
|
||||
let mut pending = connections.lock().unwrap();
|
||||
pending.pop_front()
|
||||
let pending = connections.lock().unwrap();
|
||||
pending.front().cloned()
|
||||
};
|
||||
|
||||
let Some(connection) = connection else {
|
||||
@@ -1291,6 +1293,7 @@ pub async fn start_websocket_server_with_headers(
|
||||
Ok(ws) => ws,
|
||||
Err(_) => continue,
|
||||
};
|
||||
connections.lock().unwrap().pop_front();
|
||||
|
||||
let connection_index = {
|
||||
let mut log = requests.lock().unwrap();
|
||||
|
||||
Reference in New Issue
Block a user