diff --git a/script/hassfest/quality_scale.py b/script/hassfest/quality_scale.py index 1d82ae0916a..becbeb2bd18 100644 --- a/script/hassfest/quality_scale.py +++ b/script/hassfest/quality_scale.py @@ -14,6 +14,7 @@ from homeassistant.util.yaml import load_yaml_dict from .model import Config, Integration, ScaledQualityScaleTiers from .quality_scale_validation import ( RuleValidationProtocol, + action_setup, config_entry_unloading, config_flow, diagnostics, @@ -41,7 +42,7 @@ class Rule: ALL_RULES = [ # BRONZE - Rule("action-setup", ScaledQualityScaleTiers.BRONZE), + Rule("action-setup", ScaledQualityScaleTiers.BRONZE, action_setup), Rule("appropriate-polling", ScaledQualityScaleTiers.BRONZE), Rule("brands", ScaledQualityScaleTiers.BRONZE), Rule("common-modules", ScaledQualityScaleTiers.BRONZE), diff --git a/script/hassfest/quality_scale_validation/action_setup.py b/script/hassfest/quality_scale_validation/action_setup.py new file mode 100644 index 00000000000..d8db6b9da51 --- /dev/null +++ b/script/hassfest/quality_scale_validation/action_setup.py @@ -0,0 +1,73 @@ +"""Enforce that the integration service actions are registered in async_setup. + +https://developers.home-assistant.io/docs/core/integration-quality-scale/rules/action-setup/ +""" + +import ast + +from script.hassfest import ast_parse_module +from script.hassfest.manifest import Platform +from script.hassfest.model import Config, Integration + + +def _get_setup_entry_function(module: ast.Module) -> ast.AsyncFunctionDef | None: + """Get async_setup_entry function.""" + for item in module.body: + if isinstance(item, ast.AsyncFunctionDef) and item.name == "async_setup_entry": + return item + return None + + +def _calls_service_registration( + async_setup_entry_function: ast.AsyncFunctionDef, +) -> bool: + """Check if there are calls to service registration.""" + for node in ast.walk(async_setup_entry_function): + if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)): + continue + + if node.func.attr == "async_register_entity_service": + return True + + if ( + isinstance(node.func.value, ast.Attribute) + and isinstance(node.func.value.value, ast.Name) + and node.func.value.value.id == "hass" + and node.func.value.attr == "services" + and node.func.attr in {"async_register", "register"} + ): + return True + + return False + + +def validate( + config: Config, integration: Integration, *, rules_done: set[str] +) -> list[str] | None: + """Validate that service actions are registered in async_setup.""" + + errors = [] + + module_file = integration.path / "__init__.py" + module = ast_parse_module(module_file) + if ( + async_setup_entry := _get_setup_entry_function(module) + ) and _calls_service_registration(async_setup_entry): + errors.append( + f"Integration registers services in {module_file} (async_setup_entry)" + ) + + for platform in Platform: + module_file = integration.path / f"{platform}.py" + if not module_file.exists(): + continue + module = ast_parse_module(module_file) + + if ( + async_setup_entry := _get_setup_entry_function(module) + ) and _calls_service_registration(async_setup_entry): + errors.append( + f"Integration registers services in {module_file} (async_setup_entry)" + ) + + return errors