1
0
mirror of https://github.com/home-assistant/core.git synced 2026-06-02 05:34:15 +01:00
Files
core/script/split_tests.py
T
J. Nick Koston 9ed16b63a3 Cache per-file test counts in split_tests
Persist the result of pytest --collect-only between CI runs as a JSON
file keyed by content hash, so unchanged test files are served from
cache and only edited or new files are re-collected.  The cache is
self-healing:

* Missing, corrupt, or wrong-version files fall back to a full collect.
* Any conftest.py change anywhere under the test root invalidates the
  whole cache, so fixture parametrization shifts cannot silently skew
  counts.
* Files pytest returns nothing for (helper modules named test_*.py with
  no test functions) are cached as zero so they don't get re-collected
  forever.

Walking is done once with os.walk (~2x faster than Path.rglob) and
collects test files plus conftests in a single pass.  When the cache
is fully cold we feed pytest top-level directories rather than
thousands of file paths so cold runs stay as fast as before the cache
landed.

Wire the new --cache flag through the prepare-pytest-full job and back
the cache file with actions/cache so PRs can pick up the latest dev
snapshot via restore-keys.  Local timings: cold 11s, warm with no diff
0.4s, warm with one file edited 2.3s.
2026-05-21 15:56:08 -05:00

513 lines
18 KiB
Python
Executable File

#!/usr/bin/env python3
"""Helper script to split test into n buckets."""
import argparse
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
import hashlib
import json
from math import ceil
import os
from pathlib import Path
import subprocess
import sys
from typing import Final
# Directory names that get a second level of expansion when the list of
# pytest collection targets is built (see _enumerate_batch_paths).
# tests/components has ~1000 sub-directories, which makes it the natural
# place to subdivide to keep each pytest invocation roughly equal in size.
_FAN_OUT_DIRS: Final = frozenset({"components"})
# Cache file format version, written to and checked against the "version"
# key of the cache JSON; bump on any incompatible schema change so old
# caches are ignored rather than misread.
_CACHE_VERSION: Final = 1
class Bucket:
    """One output bucket: a running test total plus the paths assigned to it."""

    def __init__(self) -> None:
        """Start with no paths and a zero test count."""
        self.total_tests = 0
        self._paths: list[str] = []

    def add(self, part: TestFolder | TestFile) -> None:
        """Assign ``part`` to this bucket.

        ``part`` is flagged as bucketed first, so a part that was already
        assigned raises before this bucket's own state is touched.
        """
        part.add_to_bucket()
        self._paths.append(str(part.path))
        self.total_tests += part.total_tests

    def get_paths_line(self) -> str:
        """Return the assigned paths, space separated, ending in a newline."""
        return f"{' '.join(self._paths)}\n"
class BucketHolder:
    """Spread test files and folders over a fixed number of buckets."""

    def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
        """Create ``bucket_count`` empty buckets sharing one size target."""
        self._tests_per_bucket = tests_per_bucket
        self._bucket_count = bucket_count
        self._buckets: list[Bucket] = [Bucket() for _ in range(bucket_count)]

    def split_tests(self, test_folder: TestFolder) -> None:
        """Assign everything under ``test_folder`` to a bucket.

        Parts are visited from most to fewest tests and offered to whichever
        bucket is currently smallest.  A folder is only taken whole when the
        bucket stays strictly below the per-bucket target; otherwise it is
        passed over and its children are handled when the loop reaches them.
        A file is always accepted and pulls its sibling files along.

        Raises ValueError if anything is left unassigned afterwards.
        """
        width = len(str(test_folder.total_tests))
        candidates = test_folder.get_all_flatten()
        candidates.sort(key=lambda item: item.total_tests, reverse=True)

        for candidate in candidates:
            if candidate.added_to_bucket:
                # Claimed earlier, either directly or through a parent folder.
                continue

            print(f"{candidate.total_tests:>{width}} tests in {candidate.path}")
            target = min(self._buckets, key=lambda bucket: bucket.total_tests)
            is_file = isinstance(candidate, TestFile)
            fits = target.total_tests + candidate.total_tests < self._tests_per_bucket
            if not (fits or is_file):
                continue

            target.add(candidate)
            if not is_file:
                continue

            # Keep every file of a folder in one bucket so that syrupy
            # correctly identifies unused snapshots.
            for sibling in candidate.parent.children.values():
                if sibling is candidate or isinstance(sibling, TestFolder):
                    continue
                print(
                    f"{sibling.total_tests:>{width}} tests in {sibling.path} (same bucket)"
                )
                target.add(sibling)

        # Every part must have ended up in some bucket.
        if not test_folder.added_to_bucket:
            raise ValueError("Not all tests are added to a bucket")

    def create_ouput_file(self) -> None:
        """Write one line of paths per bucket to ``pytest_buckets.txt``."""
        with Path("pytest_buckets.txt").open("w", encoding="utf-8") as output:
            for number, bucket in enumerate(self._buckets, start=1):
                print(f"Bucket {number} has {bucket.total_tests} tests")
                output.write(bucket.get_paths_line())
@dataclass
class TestFile:
    """A test module together with the number of tests collected from it."""

    total_tests: int
    path: Path
    # Bookkeeping that is never passed to the constructor: whether a bucket
    # has claimed this file, and the folder it was registered with (filled
    # in by TestFolder.add_test_file).
    added_to_bucket: bool = field(default=False, init=False)
    parent: TestFolder | None = field(default=None, init=False)

    def add_to_bucket(self) -> None:
        """Flag the file as assigned; assigning it twice raises ValueError."""
        if not self.added_to_bucket:
            self.added_to_bucket = True
            return
        raise ValueError("Already added to bucket")

    def __gt__(self, other: TestFile) -> bool:
        """Compare two files by their number of tests."""
        return other.total_tests < self.total_tests
class TestFolder:
    """A directory node whose children are test files and nested folders."""

    def __init__(self, path: Path) -> None:
        """Create an empty folder node for ``path``."""
        self.path: Final = path
        # Keyed by each child's own path, in insertion order.
        self.children: dict[Path, TestFolder | TestFile] = {}

    @property
    def total_tests(self) -> int:
        """Sum the test counts of all children (recursively for folders)."""
        return sum(child.total_tests for child in self.children.values())

    @property
    def added_to_bucket(self) -> bool:
        """Report whether every child has already been assigned to a bucket."""
        for child in self.children.values():
            if not child.added_to_bucket:
                return False
        return True

    def add_to_bucket(self) -> None:
        """Mark every child as assigned.

        Raises ValueError when the whole folder is already assigned.
        """
        if self.added_to_bucket:
            raise ValueError("Already added to bucket")
        for child in self.children.values():
            child.add_to_bucket()

    def __repr__(self) -> str:
        """Summarise the folder by test total and direct child count."""
        child_count = len(self.children)
        return f"TestFolder(total_tests={self.total_tests}, children={child_count})"

    def add_test_file(self, file: TestFile) -> None:
        """Insert ``file`` below this folder, creating sub-folders as needed.

        The file's ``parent`` is updated at every level it passes through,
        so it ends up pointing at the folder that stores it.

        Raises ValueError when the file's path equals this folder's path or
        when an intermediate path component is already stored as a file.
        """
        file.parent = self
        relative = file.path.relative_to(self.path)
        depth = len(relative.parts)
        if depth == 0:
            raise ValueError("Path is not a child of this folder")
        if depth == 1:
            # Direct child: store it here.
            self.children[file.path] = file
            return

        # Otherwise hand the file to the sub-folder named by the first
        # remaining path component.
        sub_path = self.path / relative.parts[0]
        sub_folder = self.children.get(sub_path)
        if sub_folder is None:
            sub_folder = self.children[sub_path] = TestFolder(sub_path)
        elif not isinstance(sub_folder, TestFolder):
            raise ValueError("Child is not a folder")
        sub_folder.add_test_file(file)

    def get_all_flatten(self) -> list[TestFolder | TestFile]:
        """Return this folder followed by all descendants, depth first."""
        flattened: list[TestFolder | TestFile] = [self]
        for child in self.children.values():
            if isinstance(child, TestFolder):
                flattened += child.get_all_flatten()
            else:
                flattened.append(child)
        return flattened
def _collect_batch(paths: list[Path]) -> tuple[str, str, int]:
    """Run ``pytest --collect-only`` over ``paths`` in a child process.

    Returns ``(stdout, stderr, returncode)``.  A non-zero exit status is
    handed back to the caller rather than raised.
    """
    command = ["pytest", "--collect-only", "-qq", "-p", "no:warnings"]
    command.extend(str(path) for path in paths)
    completed = subprocess.run(
        command,
        check=False,
        capture_output=True,
        text=True,
    )
    return completed.stdout, completed.stderr, completed.returncode
def _iter_eligible_children(path: Path) -> list[Path]:
"""Return immediate children of ``path`` that pytest should collect.
Filters out hidden/dunder entries, non-``test_*.py`` files (so helper
modules like ``conftest.py`` and ``common.py`` are not passed as
explicit collection targets), and pycache-style directories.
"""
children: list[Path] = []
for entry in sorted(path.iterdir()):
if entry.name.startswith((".", "_")):
continue
if entry.is_dir() or (entry.suffix == ".py" and entry.name.startswith("test_")):
children.append(entry)
return children
def _enumerate_batch_paths(path: Path) -> list[Path]:
"""Return the child paths to run pytest --collect-only over.
Files are returned as-is. Directories are expanded one level deep, with
a second level of expansion for entries named in ``_FAN_OUT_DIRS`` so the
enormous ``tests/components`` tree fans out into per-integration paths.
"""
if path.is_file():
return [path]
paths: list[Path] = []
for entry in _iter_eligible_children(path):
if entry.is_dir() and entry.name in _FAN_OUT_DIRS:
paths.extend(_iter_eligible_children(entry))
else:
paths.append(entry)
return paths
def _hash_file(path: Path) -> str:
"""Return a short content hash for ``path``."""
return hashlib.sha256(path.read_bytes()).hexdigest()[:16]
def _walk_test_tree(root: Path) -> tuple[list[Path], list[Path]]:
"""Walk ``root`` once and return (test files, conftest files).
Uses ``os.walk`` rather than ``Path.rglob`` because it's ~2x faster on
a 5000-file tree, and we prune hidden/dunder subdirectories instead of
visiting them. Doing both walks in one pass keeps total tree I/O down.
"""
if root.is_file():
if root.name.startswith("test_") and root.suffix == ".py":
return [root], []
return [], []
test_files: list[Path] = []
conftests: list[Path] = []
for dirpath, dirnames, filenames in os.walk(root):
dirnames[:] = [d for d in dirnames if not d.startswith((".", "_"))]
base = Path(dirpath)
for name in filenames:
if name == "conftest.py":
conftests.append(base / name)
elif name.startswith("test_") and name.endswith(".py"):
test_files.append(base / name)
test_files.sort()
conftests.sort()
return test_files, conftests
def _compute_conftest_hash(root: Path, conftests: list[Path]) -> str:
"""Return a hash that changes whenever any conftest.py under ``root`` changes.
Any change to a conftest invalidates the entire test-count cache. This is
coarse but safe: conftests can change fixture parametrization in ways the
cache cannot otherwise detect, so we just re-collect everything.
"""
digest = hashlib.sha256()
for conftest in conftests:
digest.update(str(conftest.relative_to(root)).encode())
digest.update(b"\0")
digest.update(conftest.read_bytes())
digest.update(b"\0")
return digest.hexdigest()
@dataclass
class _CacheEntry:
"""Cached test count for a single file."""
hash: str
count: int
@dataclass
class _Cache:
"""Mapping of test file path → cached entry, plus invalidation key."""
conftest_hash: str
entries: dict[str, _CacheEntry]
@classmethod
def empty(cls, conftest_hash: str = "") -> _Cache:
"""Return a new empty cache."""
return cls(conftest_hash=conftest_hash, entries={})
@classmethod
def load(cls, path: Path, current_conftest_hash: str) -> _Cache:
"""Load cache from ``path`` and invalidate it on schema/conftest drift.
Any failure (missing file, bad JSON, version drift, conftest drift)
returns an empty cache so the script just falls back to a full
collection. This is the self-healing path.
"""
try:
raw = json.loads(path.read_bytes())
except OSError, ValueError:
return cls.empty(current_conftest_hash)
if not isinstance(raw, dict) or raw.get("version") != _CACHE_VERSION:
return cls.empty(current_conftest_hash)
if raw.get("conftest_hash") != current_conftest_hash:
return cls.empty(current_conftest_hash)
files = raw.get("files")
if not isinstance(files, dict):
return cls.empty(current_conftest_hash)
entries: dict[str, _CacheEntry] = {}
for key, value in files.items():
if (
not isinstance(value, dict)
or not isinstance(value.get("hash"), str)
or not isinstance(value.get("count"), int)
):
# Skip malformed entries instead of discarding the whole cache.
continue
entries[key] = _CacheEntry(hash=value["hash"], count=value["count"])
return cls(conftest_hash=current_conftest_hash, entries=entries)
def save(self, path: Path) -> None:
"""Write the cache to ``path``."""
path.write_text(
json.dumps(
{
"version": _CACHE_VERSION,
"conftest_hash": self.conftest_hash,
"files": {
key: {"hash": entry.hash, "count": entry.count}
for key, entry in sorted(self.entries.items())
},
},
indent=2,
)
+ "\n"
)
def _resolve_from_cache(
test_files: list[Path],
cache: _Cache,
root: Path,
) -> tuple[dict[Path, int], list[Path]]:
"""Split ``test_files`` into ``(cached_counts, needs_collection)``.
A file is served from cache when its content hash matches what we
previously stored; otherwise it is queued for re-collection.
"""
cached: dict[Path, int] = {}
misses: list[Path] = []
for file in test_files:
key = str(file.relative_to(root))
entry = cache.entries.get(key)
if entry is None:
misses.append(file)
continue
if entry.hash != _hash_file(file):
misses.append(file)
continue
cached[file] = entry.count
return cached, misses
def _run_collect_batches(paths: list[Path]) -> list[tuple[str, str, int]]:
"""Run pytest --collect-only across ``paths`` using a process pool."""
workers = min(len(paths), os.cpu_count() or 1) or 1
batches = [paths[i::workers] for i in range(workers)]
if workers == 1:
return [_collect_batch(batches[0])]
with ProcessPoolExecutor(max_workers=workers) as executor:
return list(executor.map(_collect_batch, batches))
def _parse_collect_output(stdout: str) -> dict[Path, int]:
"""Parse ``pytest --collect-only -qq`` output into ``{path: count}``."""
counts: dict[Path, int] = {}
for line in stdout.splitlines():
if not line.strip():
continue
file_path, _, total_tests = line.partition(": ")
if not file_path or not total_tests:
raise ValueError(f"Unexpected line: {line}")
counts[Path(file_path)] = int(total_tests)
return counts
# Exit status pytest uses when a run finished cleanly but no tests were
# collected (``pytest.ExitCode.NO_TESTS_COLLECTED``).
_PYTEST_NO_TESTS_COLLECTED: Final = 5


def collect_tests(path: Path, cache_path: Path | None = None) -> TestFolder:
    """Collect all tests, using an on-disk cache when available.

    The tree under ``path`` is walked once.  Files whose content hash
    matches a cache entry reuse the stored count; everything else is
    counted with ``pytest --collect-only``.  When ``cache_path`` is given,
    the cache file is rebuilt from the current file list on every run, so
    deleted files are dropped and re-collected files get a refreshed hash.

    Returns a ``TestFolder`` rooted at ``path`` containing every file with
    at least one test.

    Exits the process with status 1 when no test files are found, when
    pytest reports a collection failure, or when its output cannot be
    parsed.
    """
    all_test_files, conftests = _walk_test_tree(path)
    if not all_test_files:
        print(f"No eligible test paths found under {path}")
        sys.exit(1)

    conftest_hash = _compute_conftest_hash(path, conftests)
    cache = (
        _Cache.load(cache_path, conftest_hash)
        if cache_path is not None
        else _Cache.empty(conftest_hash)
    )
    cached_counts, missing = _resolve_from_cache(all_test_files, cache, path)
    print(
        f"Cache: {len(cached_counts)} hits / {len(missing)} misses"
        f" / {len(all_test_files)} total"
    )

    new_counts: dict[Path, int] = {}
    if missing:
        # On a full cold-cache run, hand pytest the top-level directories
        # instead of 5000+ individual file paths: pytest walks dirs much
        # faster than it resolves each file argument. Once any cache hits
        # exist, use file-level collection so we only re-collect the diff.
        collect_paths = missing if cached_counts else _enumerate_batch_paths(path)
        for stdout, stderr, returncode in _run_collect_batches(collect_paths):
            if returncode == _PYTEST_NO_TESTS_COLLECTED:
                # The batch held nothing but files without tests (eg a
                # freshly edited helper module named test_*.py). That is
                # not a failure: those files simply get no count here and
                # are recorded as 0 below.
                continue
            if returncode != 0:
                print("Failed to collect tests:")
                print(stderr)
                print(stdout)
                sys.exit(1)
            try:
                new_counts.update(_parse_collect_output(stdout))
            except ValueError as err:
                print(err)
                sys.exit(1)

    counts: dict[Path, int] = {**cached_counts, **new_counts}
    folder = TestFolder(path)
    for file_path, total_tests in counts.items():
        if total_tests == 0:
            # Files with no collected tests (eg helper modules named
            # test_init.py with no test functions) shouldn't enter
            # bucketing, but we still cache them below as count=0 so
            # they don't get re-collected next run.
            continue
        folder.add_test_file(TestFile(total_tests, file_path))

    if cache_path is not None:
        updated_entries: dict[str, _CacheEntry] = {}
        for file in all_test_files:
            key = str(file.relative_to(path))
            if file in cached_counts:
                # The stored hash was just verified against the file's
                # content, so there is no need to hash the file again.
                file_hash = cache.entries[key].hash
            else:
                file_hash = _hash_file(file)
            # Every walked file is either a cache hit or was sent to
            # pytest; one pytest returned no count for has no collectible
            # tests and is cached as 0 to avoid repeating the work.
            updated_entries[key] = _CacheEntry(
                hash=file_hash, count=counts.get(file, 0)
            )
        _Cache(conftest_hash=conftest_hash, entries=updated_entries).save(cache_path)
    return folder
def main() -> None:
    """Parse the command line, collect and split the tests, write the result."""

    def check_greater_0(value: str) -> int:
        """Convert ``value`` to int for argparse, rejecting zero and below."""
        parsed = int(value)
        if parsed > 0:
            return parsed
        raise argparse.ArgumentTypeError(
            f"{value} is an invalid. Must be greater than 0"
        )

    parser = argparse.ArgumentParser(description="Split tests into n buckets.")
    parser.add_argument(
        "bucket_count",
        type=check_greater_0,
        help="Number of buckets to split tests into",
    )
    parser.add_argument(
        "path",
        type=Path,
        help="Path to the test files to split into buckets",
    )
    parser.add_argument(
        "--cache",
        type=Path,
        default=None,
        help="Path to a JSON file used to cache per-file test counts",
    )
    args = parser.parse_args()

    print("Collecting tests...")
    tests = collect_tests(args.path, args.cache)

    # Round up so that bucket_count buckets of this size can hold every test.
    tests_per_bucket = ceil(tests.total_tests / args.bucket_count)
    holder = BucketHolder(tests_per_bucket, args.bucket_count)

    print("Splitting tests...")
    holder.split_tests(tests)

    print(f"Total tests: {tests.total_tests}")
    print(f"Estimated tests per bucket: {tests_per_bucket}")
    holder.create_ouput_file()


if __name__ == "__main__":
    main()