diff --git a/codex-rs/utils/image/src/image_tests.rs b/codex-rs/utils/image/src/image_tests.rs index e60c8c6be..eb60dceca 100644 --- a/codex-rs/utils/image/src/image_tests.rs +++ b/codex-rs/utils/image/src/image_tests.rs @@ -99,7 +99,7 @@ async fn returns_original_image_when_within_bounds() { assert_eq!(encoded.width, 64); assert_eq!(encoded.height, 32); assert_eq!(encoded.mime, mime); - assert_eq!(encoded.bytes, original_bytes); + assert_eq!(encoded.bytes.as_ref(), original_bytes); } } @@ -225,7 +225,7 @@ async fn preserves_large_image_in_original_mode() { assert_eq!(processed.width, 4096); assert_eq!(processed.height, 2048); assert_eq!(processed.mime, "image/png"); - assert_eq!(processed.bytes, original_bytes); + assert_eq!(processed.bytes.as_ref(), original_bytes); } #[tokio::test(flavor = "multi_thread")] @@ -242,7 +242,7 @@ async fn data_url_processing_preserves_supported_source_bytes() { assert_eq!(processed.width, 64); assert_eq!(processed.height, 32); assert_eq!(processed.mime, "image/png"); - assert_eq!(processed.bytes, original_bytes); + assert_eq!(processed.bytes.as_ref(), original_bytes); } #[tokio::test(flavor = "multi_thread")] @@ -339,3 +339,26 @@ async fn reprocesses_updated_file_contents() { assert_eq!(second.height, 48); assert_ne!(second.bytes, first.bytes); } + +#[tokio::test(flavor = "multi_thread")] +async fn bounds_cache_by_encoded_byte_size() { + let cache = ImageCache::new(NonZeroUsize::new(4).expect("non-zero cache capacity")); + let key = |digest_byte| ImageCacheKey { + digest: [digest_byte; 20], + mode: PromptImageMode::Original, + }; + let image = |size| EncodedImage { + bytes: vec![0; size].into(), + mime: "image/png".to_string(), + width: 1, + height: 1, + }; + + cache_image(&cache, key(1), image(3), /*byte_capacity*/ 5); + cache_image(&cache, key(2), image(3), /*byte_capacity*/ 5); + cache_image(&cache, key(3), image(6), /*byte_capacity*/ 5); + + assert!(cache.get(&key(1)).is_none()); + assert!(cache.get(&key(2)).is_some()); + 
assert!(cache.get(&key(3)).is_none()); +} diff --git a/codex-rs/utils/image/src/lib.rs b/codex-rs/utils/image/src/lib.rs index 93ae427ea..770cfc9d3 100644 --- a/codex-rs/utils/image/src/lib.rs +++ b/codex-rs/utils/image/src/lib.rs @@ -1,6 +1,7 @@ use std::io::Cursor; use std::num::NonZeroUsize; use std::path::Path; +use std::sync::Arc; use std::sync::LazyLock; use base64::Engine; @@ -28,6 +29,7 @@ pub const MAX_DIMENSION: u32 = 2048; /// This is a high sanity guard against pathological inputs, not a protocol /// requirement or target upload size. pub const MAX_PROMPT_IMAGE_INPUT_BYTES: usize = 1024 * 1024 * 1024; +const MAX_IMAGE_CACHE_BYTES: usize = 64 * 1024 * 1024; pub mod error; @@ -35,7 +37,7 @@ pub use crate::error::ImageProcessingError; #[derive(Debug, Clone)] pub struct EncodedImage { - pub bytes: Vec<u8>, + pub bytes: Arc<[u8]>, pub mime: String, pub width: u32, pub height: u32, @@ -77,7 +79,9 @@ struct ImageCacheKey { mode: PromptImageMode, } -static IMAGE_CACHE: LazyLock<BlockingLruCache<ImageCacheKey, EncodedImage>> = +type ImageCache = BlockingLruCache<ImageCacheKey, EncodedImage>; + +static IMAGE_CACHE: LazyLock<ImageCache> = LazyLock::new(|| BlockingLruCache::new(NonZeroUsize::new(32).unwrap_or(NonZeroUsize::MIN))); pub fn load_for_prompt_bytes( @@ -92,7 +96,11 @@ pub fn load_for_prompt_bytes( mode, }; - IMAGE_CACHE.get_or_try_insert_with(key, move || { + if let Some(image) = IMAGE_CACHE.get(&key) { + return Ok(image); + } + + let image = (move || { let guessed_format = image::guess_format(&file_bytes) .map_err(|source| ImageProcessingError::decode_error(&path_buf, source))?; let format = match guessed_format { @@ -151,7 +159,7 @@ pub fn load_for_prompt_bytes( let (bytes, output_format) = encode_image(&resized, target_format, metadata)?; let mime = format_to_mime(output_format); EncodedImage { - bytes, + bytes: bytes.into(), mime, width, height, @@ -160,7 +168,7 @@ pub fn load_for_prompt_bytes( if let Some(format) = format.filter(|format| can_preserve_source_bytes(*format)) { let mime = format_to_mime(format); EncodedImage { - bytes: 
file_bytes, + bytes: file_bytes.into(), mime, width, height, @@ -169,7 +177,7 @@ pub fn load_for_prompt_bytes( let (bytes, output_format) = encode_image(&dynamic, ImageFormat::Png, metadata)?; let mime = format_to_mime(output_format); EncodedImage { - bytes, + bytes: bytes.into(), mime, width, height, @@ -178,7 +186,30 @@ pub fn load_for_prompt_bytes( }; Ok(encoded) - }) + })()?; + + cache_image(&IMAGE_CACHE, key, image.clone(), MAX_IMAGE_CACHE_BYTES); + Ok(image) +} + +fn cache_image(cache: &ImageCache, key: ImageCacheKey, image: EncodedImage, byte_capacity: usize) { + if image.bytes.len() > byte_capacity { + return; + } + + cache.with_mut(|cache| { + cache.put(key, image); + let mut cached_bytes = cache + .iter() + .map(|(_, image)| image.bytes.len()) + .sum::<usize>(); + while cached_bytes > byte_capacity { + let Some((_, evicted)) = cache.pop_lru() else { + break; + }; + cached_bytes -= evicted.bytes.len(); + } + }); } pub fn load_data_url_for_prompt(