diff --git a/codex-rs/app-server/BUILD.bazel b/codex-rs/app-server/BUILD.bazel index 872f533be..d2d3f42a1 100644 --- a/codex-rs/app-server/BUILD.bazel +++ b/codex-rs/app-server/BUILD.bazel @@ -4,5 +4,9 @@ codex_rust_crate( name = "app-server", crate_name = "codex_app_server", integration_test_timeout = "long", + test_shard_counts = { + "app-server-all-test": 8, + "app-server-unit-tests": 8, + }, test_tags = ["no-sandbox"], ) diff --git a/codex-rs/core/BUILD.bazel b/codex-rs/core/BUILD.bazel index 434dc1f6a..dd52bce43 100644 --- a/codex-rs/core/BUILD.bazel +++ b/codex-rs/core/BUILD.bazel @@ -47,6 +47,10 @@ codex_rust_crate( # succeeds without this workaround. "//:AGENTS.md", ], + test_shard_counts = { + "core-all-test": 8, + "core-unit-tests": 8, + }, test_tags = ["no-sandbox"], unit_test_timeout = "long", extra_binaries = [ diff --git a/codex-rs/tui/BUILD.bazel b/codex-rs/tui/BUILD.bazel index 33cc13343..d0aab0805 100644 --- a/codex-rs/tui/BUILD.bazel +++ b/codex-rs/tui/BUILD.bazel @@ -24,4 +24,7 @@ codex_rust_crate( "//codex-rs/cli:codex", ], rustc_flags_extra = MACOS_WEBRTC_RUSTC_LINK_FLAGS, + test_shard_counts = { + "tui-unit-tests": 8, + }, ) diff --git a/defs.bzl b/defs.bzl index 6972b65a6..53114a577 100644 --- a/defs.bzl +++ b/defs.bzl @@ -140,6 +140,7 @@ def codex_rust_crate( integration_test_args = [], integration_test_timeout = None, test_data_extra = [], + test_shard_counts = {}, test_tags = [], unit_test_timeout = None, extra_binaries = []): @@ -174,6 +175,11 @@ def codex_rust_crate( integration_test_timeout: Optional Bazel timeout for integration test targets generated from `tests/*.rs`. test_data_extra: Extra runtime data for tests. + test_shard_counts: Mapping from generated test target name to Bazel + shard count. Matching tests use native Bazel sharding on the + original test label, while rules_rust assigns each Rust test case + to a stable bucket by hashing the test name. Matching tests are + also marked flaky, which gives them Bazel's default three attempts. test_tags: Tags applied to unit + integration test targets. Typically used to disable the sandbox, but see https://bazel.build/reference/be/common-definitions#common.tags unit_test_timeout: Optional Bazel timeout for the unit-test target @@ -246,7 +252,13 @@ def codex_rust_crate( visibility = ["//visibility:public"], ) + unit_test_name = name + "-unit-tests" unit_test_binary = name + "-unit-tests-bin" + unit_test_shard_count = _test_shard_count(test_shard_counts, unit_test_name) + unit_test_binary_kwargs = {} + if unit_test_shard_count: + unit_test_binary_kwargs["experimental_enable_sharding"] = True + rust_test( name = unit_test_binary, crate = name, @@ -265,14 +277,18 @@ def codex_rust_crate( rustc_env = rustc_env, data = test_data_extra, tags = test_tags + ["manual"], + **unit_test_binary_kwargs ) unit_test_kwargs = {} if unit_test_timeout: unit_test_kwargs["timeout"] = unit_test_timeout + if unit_test_shard_count: + unit_test_kwargs["shard_count"] = unit_test_shard_count + unit_test_kwargs["flaky"] = True workspace_root_test( - name = name + "-unit-tests", + name = unit_test_name, env = test_env, test_bin = ":" + unit_test_binary, workspace_root_marker = "//codex-rs/utils/cargo-bin:repo_root.marker", @@ -318,6 +334,14 @@ def codex_rust_crate( if not test_name.endswith("-test"): test_name += "-test" + test_kwargs = {} + test_kwargs.update(integration_test_kwargs) + test_shard_count = _test_shard_count(test_shard_counts, test_name) + if test_shard_count: + test_kwargs["experimental_enable_sharding"] = True + test_kwargs["shard_count"] = test_shard_count + test_kwargs["flaky"] = True + rust_test( name = test_name, crate_name = test_crate_name, @@ -339,5 +363,15 @@ def codex_rust_crate( # execute from the repo root and can misplace integration snapshots. env = cargo_env, tags = test_tags, - **integration_test_kwargs + **test_kwargs ) + +def _test_shard_count(test_shard_counts, test_name): + shard_count = test_shard_counts.get(test_name) + if shard_count == None: + return None + + if shard_count < 1: + fail("test_shard_counts[{}] must be a positive integer".format(test_name)) + + return shard_count