mirror of
https://github.com/home-assistant/core.git
synced 2026-06-02 05:34:15 +01:00
9ed16b63a3
Persist the result of pytest --collect-only between CI runs as a JSON file keyed by content hash, so unchanged test files are served from cache and only edited or new files are re-collected. The cache is self-healing: * Missing, corrupt, or wrong-version files fall back to a full collect. * Any conftest.py change anywhere under the test root invalidates the whole cache, so fixture parametrization shifts cannot silently skew counts. * Files pytest returns nothing for (helper modules named test_*.py with no test functions) are cached as zero so they don't get re-collected forever. Walking is done once with os.walk (~2x faster than Path.rglob) and collects test files plus conftests in a single pass. When the cache is fully cold we feed pytest top-level directories rather than thousands of file paths so cold runs stay as fast as before the cache landed. Wire the new --cache flag through the prepare-pytest-full job and back the cache file with actions/cache so PRs can pick up the latest dev snapshot via restore-keys. Local timings: cold 11s, warm with no diff 0.4s, warm with one file edited 2.3s.
513 lines
18 KiB
Python
Executable File
513 lines
18 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Helper script to split tests into n buckets."""
|
|
|
|
import argparse
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from dataclasses import dataclass, field
|
|
import hashlib
|
|
import json
|
|
from math import ceil
|
|
import os
|
|
from pathlib import Path
|
|
import subprocess
|
|
import sys
|
|
from typing import Final
|
|
|
|
# tests/components has ~1000 sub-directories, which makes it the natural
|
|
# place to subdivide to keep each pytest invocation roughly equal in size.
|
|
_FAN_OUT_DIRS: Final = frozenset({"components"})
|
|
|
|
# Cache file format version; bump on any incompatible schema change so old
|
|
# caches are ignored rather than misread.
|
|
_CACHE_VERSION: Final = 1
|
|
|
|
|
|
class Bucket:
    """One output bucket: the paths assigned to it and their running test total."""

    def __init__(self) -> None:
        """Create an empty bucket with no paths and zero tests."""
        self.total_tests = 0
        self._paths: list[str] = []

    def add(self, part: TestFolder | TestFile) -> None:
        """Assign ``part`` (a folder or a single file) to this bucket.

        ``part`` is marked as bucketed before anything here changes, so a part
        that was already placed raises from its own ``add_to_bucket`` and
        leaves this bucket untouched.
        """
        part.add_to_bucket()
        self.total_tests = self.total_tests + part.total_tests
        self._paths.append(str(part.path))

    def get_paths_line(self) -> str:
        """Return the bucket's paths separated by spaces, ending in a newline."""
        return f"{' '.join(self._paths)}\n"
|
|
|
|
|
|
class BucketHolder:
    """Own a fixed set of buckets and distribute a collected test tree over them."""

    def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
        """Create ``bucket_count`` empty buckets.

        ``tests_per_bucket`` is the size target used to decide whether a whole
        folder may be placed into a single bucket.
        """
        self._tests_per_bucket = tests_per_bucket
        self._bucket_count = bucket_count
        self._buckets: list[Bucket] = [Bucket() for _ in range(bucket_count)]

    def split_tests(self, test_folder: TestFolder) -> None:
        """Assign every test file under ``test_folder`` to a bucket.

        Folders and files are visited from most to fewest tests. Each one goes
        to the bucket that currently holds the fewest tests. A folder is placed
        whole only if that keeps the bucket below ``tests_per_bucket``. A file
        is always placed, together with every sibling file of its folder.

        Raises:
            ValueError: if some file is left without a bucket afterwards.
        """
        width = len(str(test_folder.total_tests))
        by_size = sorted(
            test_folder.get_all_flatten(),
            key=lambda node: node.total_tests,
            reverse=True,
        )
        for node in by_size:
            if node.added_to_bucket:
                # Already placed as part of a folder or a sibling group.
                continue

            print(f"{node.total_tests:>{width}} tests in {node.path}")
            target = min(self._buckets, key=lambda bucket: bucket.total_tests)
            node_is_file = isinstance(node, TestFile)
            fits = target.total_tests + node.total_tests < self._tests_per_bucket
            if not (fits or node_is_file):
                # Too large to place whole; its children appear later in the
                # size-ordered list and are placed individually.
                continue

            target.add(node)
            if not node_is_file:
                continue

            # Keep all files of one folder in the same bucket so syrupy can
            # correctly identify unused snapshots.
            for sibling in node.parent.children.values():
                if sibling is node or isinstance(sibling, TestFolder):
                    continue
                print(
                    f"{sibling.total_tests:>{width}} tests in {sibling.path}"
                    " (same bucket)"
                )
                target.add(sibling)

        # Every file must have ended up in some bucket.
        if not test_folder.added_to_bucket:
            raise ValueError("Not all tests are added to a bucket")

    def create_ouput_file(self) -> None:
        """Write one line of space-separated paths per bucket to pytest_buckets.txt."""
        output = Path("pytest_buckets.txt")
        with output.open("w", encoding="utf-8") as handle:
            for number, bucket in enumerate(self._buckets, start=1):
                print(f"Bucket {number} has {bucket.total_tests} tests")
                handle.write(bucket.get_paths_line())
|
|
|
|
|
|
@dataclass
class TestFile:
    """A single test module and the number of tests collected from it."""

    total_tests: int
    path: Path
    # Flipped to True once the file has been assigned to a bucket.
    added_to_bucket: bool = field(default=False, init=False)
    # Folder node holding this file; assigned by TestFolder.add_test_file.
    parent: TestFolder | None = field(default=None, init=False)

    def add_to_bucket(self) -> None:
        """Mark the file as bucketed.

        Raises:
            ValueError: if the file was already marked.
        """
        if not self.added_to_bucket:
            self.added_to_bucket = True
            return
        raise ValueError("Already added to bucket")

    def __gt__(self, other: TestFile) -> bool:
        """Compare files by how many tests they contain."""
        return other.total_tests < self.total_tests
|
|
|
|
|
|
class TestFolder:
    """A directory in the test tree, holding test files and sub-folders.

    Children are keyed by their full path. Test totals and bucket state are
    derived from the children on every access rather than stored.
    """

    def __init__(self, path: Path) -> None:
        """Create an empty folder node for ``path``."""
        self.path: Final = path
        self.children: dict[Path, TestFolder | TestFile] = {}

    @property
    def total_tests(self) -> int:
        """Return the number of tests in this folder, recursively."""
        # A generator avoids materialising a throwaway list just for sum().
        return sum(child.total_tests for child in self.children.values())

    @property
    def added_to_bucket(self) -> bool:
        """Return True when every child has been assigned to a bucket.

        A folder without children counts as added (``all`` of nothing is True).
        """
        return all(child.added_to_bucket for child in self.children.values())

    def add_to_bucket(self) -> None:
        """Mark every child, recursively, as assigned to a bucket.

        Raises:
            ValueError: if this folder is already fully added. A child that
                was already added on its own raises ValueError from its own
                ``add_to_bucket``.
        """
        if self.added_to_bucket:
            raise ValueError("Already added to bucket")
        for child in self.children.values():
            child.add_to_bucket()

    def __repr__(self) -> str:
        """Return a short summary with the test total and child count."""
        return (
            f"TestFolder(total_tests={self.total_tests}, children={len(self.children)})"
        )

    def add_test_file(self, file: TestFile) -> None:
        """Insert ``file`` below this folder, creating sub-folders as needed.

        ``file.parent`` is set to the folder that finally stores the file.
        It is left untouched when the file is rejected.

        Raises:
            ValueError: if ``file.path`` is not strictly below this folder, or
                if a path component that should be a folder is already a file.
        """
        path = file.path
        # ``relative_to`` itself raises ValueError for paths outside self.path.
        relative_path = path.relative_to(self.path)
        if not relative_path.parts:
            # ``path`` is this folder itself, not a child of it.
            raise ValueError("Path is not a child of this folder")

        if len(relative_path.parts) == 1:
            # Direct child: only now is it certain this folder holds the file.
            file.parent = self
            self.children[path] = file
            return

        child_path = self.path / relative_path.parts[0]
        if (child := self.children.get(child_path)) is None:
            self.children[child_path] = child = TestFolder(child_path)
        elif not isinstance(child, TestFolder):
            raise ValueError("Child is not a folder")
        child.add_test_file(file)

    def get_all_flatten(self) -> list[TestFolder | TestFile]:
        """Return this folder followed by all descendants, depth-first.

        Every folder appears before its own children, and children keep their
        insertion order.
        """
        result: list[TestFolder | TestFile] = [self]
        for child in self.children.values():
            if isinstance(child, TestFolder):
                result.extend(child.get_all_flatten())
            else:
                result.append(child)
        return result
|
|
|
|
|
|
def _collect_batch(paths: list[Path]) -> tuple[str, str, int]:
    """Collect (without running) the tests under ``paths`` in a pytest subprocess.

    Returns the subprocess's ``(stdout, stderr, returncode)``. A non-zero exit
    is not raised here; the caller decides how to handle it.
    """
    command = ["pytest", "--collect-only", "-qq", "-p", "no:warnings"]
    command.extend(str(target) for target in paths)
    completed = subprocess.run(command, capture_output=True, text=True, check=False)
    return completed.stdout, completed.stderr, completed.returncode
|
|
|
|
|
|
def _iter_eligible_children(path: Path) -> list[Path]:
|
|
"""Return immediate children of ``path`` that pytest should collect.
|
|
|
|
Filters out hidden/dunder entries, non-``test_*.py`` files (so helper
|
|
modules like ``conftest.py`` and ``common.py`` are not passed as
|
|
explicit collection targets), and pycache-style directories.
|
|
"""
|
|
children: list[Path] = []
|
|
for entry in sorted(path.iterdir()):
|
|
if entry.name.startswith((".", "_")):
|
|
continue
|
|
if entry.is_dir() or (entry.suffix == ".py" and entry.name.startswith("test_")):
|
|
children.append(entry)
|
|
return children
|
|
|
|
|
|
def _enumerate_batch_paths(path: Path) -> list[Path]:
|
|
"""Return the child paths to run pytest --collect-only over.
|
|
|
|
Files are returned as-is. Directories are expanded one level deep, with
|
|
a second level of expansion for entries named in ``_FAN_OUT_DIRS`` so the
|
|
enormous ``tests/components`` tree fans out into per-integration paths.
|
|
"""
|
|
if path.is_file():
|
|
return [path]
|
|
|
|
paths: list[Path] = []
|
|
for entry in _iter_eligible_children(path):
|
|
if entry.is_dir() and entry.name in _FAN_OUT_DIRS:
|
|
paths.extend(_iter_eligible_children(entry))
|
|
else:
|
|
paths.append(entry)
|
|
return paths
|
|
|
|
|
|
def _hash_file(path: Path) -> str:
|
|
"""Return a short content hash for ``path``."""
|
|
return hashlib.sha256(path.read_bytes()).hexdigest()[:16]
|
|
|
|
|
|
def _walk_test_tree(root: Path) -> tuple[list[Path], list[Path]]:
|
|
"""Walk ``root`` once and return (test files, conftest files).
|
|
|
|
Uses ``os.walk`` rather than ``Path.rglob`` because it's ~2x faster on
|
|
a 5000-file tree, and we prune hidden/dunder subdirectories instead of
|
|
visiting them. Doing both walks in one pass keeps total tree I/O down.
|
|
"""
|
|
if root.is_file():
|
|
if root.name.startswith("test_") and root.suffix == ".py":
|
|
return [root], []
|
|
return [], []
|
|
|
|
test_files: list[Path] = []
|
|
conftests: list[Path] = []
|
|
for dirpath, dirnames, filenames in os.walk(root):
|
|
dirnames[:] = [d for d in dirnames if not d.startswith((".", "_"))]
|
|
base = Path(dirpath)
|
|
for name in filenames:
|
|
if name == "conftest.py":
|
|
conftests.append(base / name)
|
|
elif name.startswith("test_") and name.endswith(".py"):
|
|
test_files.append(base / name)
|
|
test_files.sort()
|
|
conftests.sort()
|
|
return test_files, conftests
|
|
|
|
|
|
def _compute_conftest_hash(root: Path, conftests: list[Path]) -> str:
|
|
"""Return a hash that changes whenever any conftest.py under ``root`` changes.
|
|
|
|
Any change to a conftest invalidates the entire test-count cache. This is
|
|
coarse but safe: conftests can change fixture parametrization in ways the
|
|
cache cannot otherwise detect, so we just re-collect everything.
|
|
"""
|
|
digest = hashlib.sha256()
|
|
for conftest in conftests:
|
|
digest.update(str(conftest.relative_to(root)).encode())
|
|
digest.update(b"\0")
|
|
digest.update(conftest.read_bytes())
|
|
digest.update(b"\0")
|
|
return digest.hexdigest()
|
|
|
|
|
|
@dataclass
|
|
class _CacheEntry:
|
|
"""Cached test count for a single file."""
|
|
|
|
hash: str
|
|
count: int
|
|
|
|
|
|
@dataclass
class _Cache:
    """Per-file test counts plus the conftest hash they were collected under."""

    # Digest of all conftest.py files at the time the counts were collected.
    conftest_hash: str
    # Test-file path (relative to the test root) -> cached hash and count.
    entries: dict[str, _CacheEntry]

    @classmethod
    def empty(cls, conftest_hash: str = "") -> _Cache:
        """Return a cache with no entries, tagged with ``conftest_hash``."""
        return cls(conftest_hash=conftest_hash, entries={})

    @classmethod
    def load(cls, path: Path, current_conftest_hash: str) -> _Cache:
        """Load cache from ``path`` and invalidate it on schema/conftest drift.

        Any failure (missing file, bad JSON, version drift, conftest drift)
        returns an empty cache, so the script falls back to a full
        collection. This is the self-healing path. Individual malformed
        entries are skipped rather than discarding the whole cache.
        """
        try:
            raw = json.loads(path.read_bytes())
        except (OSError, ValueError):
            # Missing/unreadable file, or undecodable content: JSONDecodeError
            # and UnicodeDecodeError are both ValueError subclasses.
            return cls.empty(current_conftest_hash)
        if not isinstance(raw, dict) or raw.get("version") != _CACHE_VERSION:
            return cls.empty(current_conftest_hash)
        if raw.get("conftest_hash") != current_conftest_hash:
            return cls.empty(current_conftest_hash)
        files = raw.get("files")
        if not isinstance(files, dict):
            return cls.empty(current_conftest_hash)

        entries: dict[str, _CacheEntry] = {}
        for key, value in files.items():
            if not isinstance(value, dict):
                continue
            file_hash = value.get("hash")
            count = value.get("count")
            # ``bool`` is an ``int`` subclass and a negative count is
            # meaningless; treat both as corrupt so the file is re-collected.
            if (
                not isinstance(file_hash, str)
                or not isinstance(count, int)
                or isinstance(count, bool)
                or count < 0
            ):
                continue
            entries[key] = _CacheEntry(hash=file_hash, count=count)
        return cls(conftest_hash=current_conftest_hash, entries=entries)

    def save(self, path: Path) -> None:
        """Write the cache to ``path`` as indented JSON with sorted keys.

        Missing parent directories are created, so a cache location inside a
        directory that does not exist yet (e.g. on a first run) still works.
        """
        payload = {
            "version": _CACHE_VERSION,
            "conftest_hash": self.conftest_hash,
            "files": {
                key: {"hash": entry.hash, "count": entry.count}
                for key, entry in sorted(self.entries.items())
            },
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
|
|
|
|
|
|
def _resolve_from_cache(
|
|
test_files: list[Path],
|
|
cache: _Cache,
|
|
root: Path,
|
|
) -> tuple[dict[Path, int], list[Path]]:
|
|
"""Split ``test_files`` into ``(cached_counts, needs_collection)``.
|
|
|
|
A file is served from cache when its content hash matches what we
|
|
previously stored; otherwise it is queued for re-collection.
|
|
"""
|
|
cached: dict[Path, int] = {}
|
|
misses: list[Path] = []
|
|
for file in test_files:
|
|
key = str(file.relative_to(root))
|
|
entry = cache.entries.get(key)
|
|
if entry is None:
|
|
misses.append(file)
|
|
continue
|
|
if entry.hash != _hash_file(file):
|
|
misses.append(file)
|
|
continue
|
|
cached[file] = entry.count
|
|
return cached, misses
|
|
|
|
|
|
def _run_collect_batches(paths: list[Path]) -> list[tuple[str, str, int]]:
|
|
"""Run pytest --collect-only across ``paths`` using a process pool."""
|
|
workers = min(len(paths), os.cpu_count() or 1) or 1
|
|
batches = [paths[i::workers] for i in range(workers)]
|
|
if workers == 1:
|
|
return [_collect_batch(batches[0])]
|
|
with ProcessPoolExecutor(max_workers=workers) as executor:
|
|
return list(executor.map(_collect_batch, batches))
|
|
|
|
|
|
def _parse_collect_output(stdout: str) -> dict[Path, int]:
|
|
"""Parse ``pytest --collect-only -qq`` output into ``{path: count}``."""
|
|
counts: dict[Path, int] = {}
|
|
for line in stdout.splitlines():
|
|
if not line.strip():
|
|
continue
|
|
file_path, _, total_tests = line.partition(": ")
|
|
if not file_path or not total_tests:
|
|
raise ValueError(f"Unexpected line: {line}")
|
|
counts[Path(file_path)] = int(total_tests)
|
|
return counts
|
|
|
|
|
|
# Exit status pytest reports (``pytest.ExitCode.NO_TESTS_COLLECTED``) when a
# run collects nothing, e.g. a batch made up only of helper modules named
# ``test_*.py`` that define no test functions.
_PYTEST_NO_TESTS_COLLECTED: Final = 5


def collect_tests(path: Path, cache_path: Path | None = None) -> TestFolder:
    """Collect per-file test counts under ``path``, reusing a cache if given.

    Test files whose content hash matches the cache at ``cache_path`` reuse
    their stored count; all others are re-collected with
    ``pytest --collect-only``. When ``cache_path`` is set, the cache is
    rebuilt from this run's results and written back. Files with zero tests
    are left out of the returned tree.

    Exits the process with status 1 when no test files exist under ``path``,
    when a pytest batch fails, or when its output cannot be parsed.
    """
    all_test_files, conftests = _walk_test_tree(path)
    conftest_hash = _compute_conftest_hash(path, conftests)
    cache = (
        _Cache.load(cache_path, conftest_hash)
        if cache_path is not None
        else _Cache.empty(conftest_hash)
    )

    if not all_test_files:
        print(f"No eligible test paths found under {path}")
        sys.exit(1)

    cached_counts, missing = _resolve_from_cache(all_test_files, cache, path)
    print(
        f"Cache: {len(cached_counts)} hits / {len(missing)} misses"
        f" / {len(all_test_files)} total"
    )

    new_counts: dict[Path, int] = {}
    if missing:
        # On a full cold-cache run, hand pytest the top-level directories
        # instead of 5000+ individual file paths: pytest walks dirs much
        # faster than it resolves each file argument. Once any cache hits
        # exist, use file-level collection so we only re-collect the diff.
        if not cached_counts:
            collect_paths = _enumerate_batch_paths(path)
        else:
            collect_paths = missing
        for stdout, stderr, returncode in _run_collect_batches(collect_paths):
            if returncode == _PYTEST_NO_TESTS_COLLECTED:
                # Not an error: nothing in this batch has tests. Its files get
                # no count here and, when a cache is in use, are stored as 0.
                continue
            if returncode != 0:
                print("Failed to collect tests:")
                print(stderr)
                print(stdout)
                sys.exit(1)
            try:
                new_counts.update(_parse_collect_output(stdout))
            except ValueError as err:
                print(err)
                sys.exit(1)

    counts: dict[Path, int] = {**cached_counts, **new_counts}

    folder = TestFolder(path)
    for file_path, total_tests in counts.items():
        if total_tests == 0:
            # Files with no collected tests (eg helper modules named
            # test_init.py with no test functions) shouldn't enter
            # bucketing, but we still cache them below as count=0 so
            # they don't get re-collected next run.
            continue
        folder.add_test_file(TestFile(total_tests, file_path))

    if cache_path is not None:
        # Rebuild the cache from scratch on every run so deleted files are
        # dropped and re-collected files get a refreshed hash.
        missing_set = set(missing)
        updated_entries: dict[str, _CacheEntry] = {}
        for file in all_test_files:
            if file in counts:
                count = counts[file]
            elif file in missing_set:
                # We asked pytest about this file and got no count back,
                # so it has no collectible tests; cache it as 0 to avoid
                # repeating the work next run.
                count = 0
            else:
                continue
            updated_entries[str(file.relative_to(path))] = _CacheEntry(
                hash=_hash_file(file), count=count
            )
        _Cache(conftest_hash=conftest_hash, entries=updated_entries).save(cache_path)

    return folder
|
|
|
|
|
|
def main() -> None:
    """Split the tests under ``path`` into ``bucket_count`` buckets.

    Parses the CLI arguments and collects per-file test counts, optionally
    through the ``--cache`` file. It then distributes the counts over the
    buckets and writes the result to ``pytest_buckets.txt``.
    """
    parser = argparse.ArgumentParser(description="Split tests into n buckets.")

    def check_greater_0(value: str) -> int:
        """Parse ``value`` as an int for argparse, rejecting values below 1."""
        ivalue = int(value)
        if ivalue <= 0:
            raise argparse.ArgumentTypeError(
                f"{value} is an invalid value. Must be greater than 0"
            )
        return ivalue

    parser.add_argument(
        "bucket_count",
        help="Number of buckets to split tests into",
        type=check_greater_0,
    )
    parser.add_argument(
        "path",
        help="Path to the test files to split into buckets",
        type=Path,
    )
    parser.add_argument(
        "--cache",
        help="Path to a JSON file used to cache per-file test counts",
        type=Path,
        default=None,
    )

    arguments = parser.parse_args()

    print("Collecting tests...")
    tests = collect_tests(arguments.path, arguments.cache)
    # Round up so the bucket target times the bucket count covers every test.
    tests_per_bucket = ceil(tests.total_tests / arguments.bucket_count)

    bucket_holder = BucketHolder(tests_per_bucket, arguments.bucket_count)
    print("Splitting tests...")
    bucket_holder.split_tests(tests)

    print(f"Total tests: {tests.total_tests}")
    print(f"Estimated tests per bucket: {tests_per_bucket}")

    bucket_holder.create_ouput_file()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
|