From 2d33a720f76baf4bf011c64d164005ac1712fa6e Mon Sep 17 00:00:00 2001 From: Artur Pragacz <49985303+arturpragacz@users.noreply.github.com> Date: Tue, 16 Dec 2025 10:46:10 +0100 Subject: [PATCH] Modernise condition checker in helper (#159159) --- .../components/device_automation/condition.py | 21 ++++--- homeassistant/components/light/condition.py | 17 +++--- homeassistant/components/sun/condition.py | 15 +++-- homeassistant/components/zone/condition.py | 15 +++-- homeassistant/helpers/condition.py | 55 ++++++++++++++++--- homeassistant/helpers/script.py | 20 ++++--- tests/helpers/test_condition.py | 10 ++-- 7 files changed, 99 insertions(+), 54 deletions(-) diff --git a/homeassistant/components/device_automation/condition.py b/homeassistant/components/device_automation/condition.py index f9894f6658e..dde1ee7bfe0 100644 --- a/homeassistant/components/device_automation/condition.py +++ b/homeassistant/components/device_automation/condition.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import Any, Protocol import voluptuous as vol @@ -11,18 +11,15 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv from homeassistant.helpers.condition import ( Condition, + ConditionChecker, ConditionCheckerType, ConditionConfig, - trace_condition_function, ) -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, TemplateVarsType from . import DeviceAutomationType, async_get_device_automation_platform from .helpers import async_validate_device_automation_config -if TYPE_CHECKING: - from homeassistant.helpers import condition - class DeviceAutomationConditionProtocol(Protocol): """Define the format of device_condition modules. @@ -90,15 +87,21 @@ class DeviceCondition(Condition): assert config.options is not None self._config = config.options - async def async_get_checker(self) -> condition.ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Test a device condition.""" platform = await async_get_device_automation_platform( self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION ) - return trace_condition_function( - platform.async_condition_from_config(self._hass, self._config) + platform_checker = platform.async_condition_from_config( + self._hass, self._config ) + def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool: + result = platform_checker(self._hass, variables) + return result is not False + + return checker + CONDITIONS: dict[str, type[Condition]] = { "_device": DeviceCondition, diff --git a/homeassistant/components/light/condition.py b/homeassistant/components/light/condition.py index 1c1b178c002..139f9e71ebc 100644 --- a/homeassistant/components/light/condition.py +++ b/homeassistant/components/light/condition.py @@ -1,7 +1,7 @@ """Provides conditions for lights.""" from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Final, override +from typing import TYPE_CHECKING, Any, Final, Unpack, override import voluptuous as vol @@ -10,11 +10,11 @@ from homeassistant.core import HomeAssistant, split_entity_id from homeassistant.helpers import config_validation as cv, target from homeassistant.helpers.condition import ( Condition, - ConditionCheckerType, + ConditionChecker, + ConditionCheckParams, ConditionConfig, - trace_condition_function, ) -from homeassistant.helpers.typing import ConfigType, TemplateVarsType +from homeassistant.helpers.typing import ConfigType from .const import DOMAIN @@ -61,7 +61,7 @@ class StateConditionBase(Condition): self._state = state @override - async def async_get_checker(self) -> ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Get the condition checker.""" def check_any_match_state(states: list[str]) -> bool: @@ -78,12 +78,11 @@ class StateConditionBase(Condition): elif self._behavior == BEHAVIOR_ALL: matcher = check_all_match_state - @trace_condition_function - def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool: """Test state condition.""" target_selection = target.TargetSelection(self._target) targeted_entities = target.async_extract_referenced_entity_ids( - hass, target_selection, expand_group=False + self._hass, target_selection, expand_group=False ) referenced_entity_ids = targeted_entities.referenced.union( targeted_entities.indirectly_referenced @@ -96,7 +95,7 @@ class StateConditionBase(Condition): light_entity_states = [ state.state for entity_id in light_entity_ids - if (state := hass.states.get(entity_id)) + if (state := self._hass.states.get(entity_id)) and state.state in STATE_CONDITION_VALID_STATES ] return matcher(light_entity_states) diff --git a/homeassistant/components/sun/condition.py b/homeassistant/components/sun/condition.py index 1a4a9a4c6db..9d4dc0764ee 100644 --- a/homeassistant/components/sun/condition.py +++ b/homeassistant/components/sun/condition.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import Any, cast +from typing import Any, Unpack, cast import voluptuous as vol @@ -13,14 +13,14 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.condition import ( Condition, - ConditionCheckerType, + ConditionChecker, + ConditionCheckParams, ConditionConfig, condition_trace_set_result, condition_trace_update_result, - trace_condition_function, ) from homeassistant.helpers.sun import get_astral_event_date -from homeassistant.helpers.typing import ConfigType, TemplateVarsType +from homeassistant.helpers.typing import ConfigType from homeassistant.util import dt as dt_util _OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = { @@ -154,17 +154,16 @@ class SunCondition(Condition): assert config.options is not None self._options = config.options - async def async_get_checker(self) -> ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Wrap action method with sun based condition.""" before = self._options.get("before") after = self._options.get("after") before_offset = self._options.get("before_offset") after_offset = self._options.get("after_offset") - @trace_condition_function - def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool: """Validate time based if-condition.""" - return sun(hass, before, after, before_offset, after_offset) + return sun(self._hass, before, after, before_offset, after_offset) return sun_if diff --git a/homeassistant/components/zone/condition.py b/homeassistant/components/zone/condition.py index d106ea092a8..ee3f286c660 100644 --- a/homeassistant/components/zone/condition.py +++ b/homeassistant/components/zone/condition.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any, Unpack, cast import voluptuous as vol @@ -22,11 +22,11 @@ from homeassistant.helpers import config_validation as cv from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.condition import ( Condition, - ConditionCheckerType, + ConditionChecker, + ConditionCheckParams, ConditionConfig, - trace_condition_function, ) -from homeassistant.helpers.typing import ConfigType, TemplateVarsType +from homeassistant.helpers.typing import ConfigType from . import in_zone @@ -118,13 +118,12 @@ class ZoneCondition(Condition): assert config.options is not None self._options = config.options - async def async_get_checker(self) -> ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Wrap action method with zone based condition.""" entity_ids = self._options.get(CONF_ENTITY_ID, []) zone_entity_ids = self._options.get(CONF_ZONE, []) - @trace_condition_function - def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool: """Test if condition.""" errors = [] @@ -133,7 +132,7 @@ class ZoneCondition(Condition): entity_ok = False for zone_entity_id in zone_entity_ids: try: - if zone(hass, zone_entity_id, entity_id): + if zone(self._hass, zone_entity_id, entity_id): entity_ok = True except ConditionErrorMessage as ex: errors.append( diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index b99079822d8..957ff25434f 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -13,7 +13,7 @@ import inspect import logging import re import sys -from typing import TYPE_CHECKING, Any, Protocol, cast +from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload import voluptuous as vol @@ -298,7 +298,7 @@ class Condition(abc.ABC): self._hass = hass @abc.abstractmethod - async def async_get_checker(self) -> ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Get the condition checker.""" @@ -319,7 +319,23 @@ class ConditionConfig: target: dict[str, Any] | None = None -type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None] +class ConditionCheckParams(TypedDict, total=False): + """Condition check params.""" + + variables: TemplateVarsType + + +class ConditionChecker(Protocol): + """Protocol for condition checker callable with typed kwargs.""" + + def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool: + """Check the condition.""" + + +type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool] +type ConditionCheckerTypeOptional = Callable[ + [HomeAssistant, TemplateVarsType], bool | None +] def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement: @@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]: trace_stack_pop(trace_stack_cv) -def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType: +@overload +def trace_condition_function( + condition: ConditionCheckerType, +) -> ConditionCheckerType: ... + + +@overload +def trace_condition_function( + condition: ConditionCheckerTypeOptional, +) -> ConditionCheckerTypeOptional: ... + + +def trace_condition_function( + condition: ConditionCheckerType | ConditionCheckerTypeOptional, +) -> ConditionCheckerType | ConditionCheckerTypeOptional: """Wrap a condition function to enable basic tracing.""" @ft.wraps(condition) @@ -420,10 +450,20 @@ async def _async_get_condition_platform( ) from None +async def _async_get_checker(condition: Condition) -> ConditionCheckerType: + new_checker = await condition.async_get_checker() + + @trace_condition_function + def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + return new_checker(variables=variables) + + return checker + + async def async_from_config( hass: HomeAssistant, config: ConfigType, -) -> ConditionCheckerType: +) -> ConditionCheckerTypeOptional: """Turn a condition configuration into a method. Should be run on the event loop. @@ -466,7 +506,7 @@ async def async_from_config( target=config.get(CONF_TARGET), ), ) - return await condition.async_get_checker() + return await _async_get_checker(condition) for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): factory = getattr(sys.modules[__name__], fmt.format(condition_key), None) @@ -1131,7 +1171,7 @@ async def async_conditions_from_config( name: str, ) -> Callable[[TemplateVarsType], bool]: """AND all conditions.""" - checks: list[ConditionCheckerType] = [ + checks = [ await async_from_config(hass, condition_config) for condition_config in condition_configs ] @@ -1330,7 +1370,6 @@ async def async_get_all_descriptions( continue description = {"fields": yaml_description.get("fields", {})} - if (target := yaml_description.get("target")) is not None: description["target"] = target diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 937968e9742..3d7b99d571c 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -86,7 +86,7 @@ from homeassistant.util.hass_dict import HassKey from homeassistant.util.signal_type import SignalType, SignalTypeFormat from . import condition, config_validation as cv, service, template -from .condition import ConditionCheckerType, trace_condition_function +from .condition import ConditionCheckerTypeOptional, trace_condition_function from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal from .event import async_call_later, async_track_template from .script_variables import ScriptRunVariables, ScriptVariables @@ -675,12 +675,14 @@ class _ScriptRun: ### Condition actions ### - async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType: + async def _async_get_condition( + self, config: ConfigType + ) -> ConditionCheckerTypeOptional: return await self._script._async_get_condition(config) # noqa: SLF001 def _test_conditions( self, - conditions: list[ConditionCheckerType], + conditions: list[ConditionCheckerTypeOptional], name: str, condition_path: str | None = None, ) -> bool | None: @@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None: class _ChooseData(TypedDict): - choices: list[tuple[list[ConditionCheckerType], Script]] + choices: list[tuple[list[ConditionCheckerTypeOptional], Script]] default: Script | None class _IfData(TypedDict): - if_conditions: list[ConditionCheckerType] + if_conditions: list[ConditionCheckerTypeOptional] if_then: Script if_else: Script | None @@ -1486,7 +1488,9 @@ class Script: self._max_exceeded = max_exceeded if script_mode == SCRIPT_MODE_QUEUED: self._queue_lck = asyncio.Lock() - self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {} + self._config_cache: dict[ + frozenset[tuple[str, str]], ConditionCheckerTypeOptional + ] = {} self._repeat_script: dict[int, Script] = {} self._choose_data: dict[int, _ChooseData] = {} self._if_data: dict[int, _IfData] = {} @@ -1857,7 +1861,9 @@ class Script: return await asyncio.shield(create_eager_task(self._async_stop(aws, update_state))) - async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType: + async def _async_get_condition( + self, config: ConfigType + ) -> ConditionCheckerTypeOptional: config_cache_key = frozenset((k, str(v)) for k, v in config.items()) if not (cond := self._config_cache.get(config_cache_key)): cond = await condition.async_from_config(self._hass, config) diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index 830154f9c0a..92702b6f1a3 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -36,7 +36,7 @@ from homeassistant.helpers import ( from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.condition import ( Condition, - ConditionCheckerType, + ConditionChecker, async_validate_condition_config, ) from homeassistant.helpers.template import Template @@ -2126,16 +2126,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None: class MockCondition1(MockCondition): """Mock condition 1.""" - async def async_get_checker(self) -> ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Evaluate state based on configuration.""" - return lambda hass, vars: True + return lambda **kwargs: True class MockCondition2(MockCondition): """Mock condition 2.""" - async def async_get_checker(self) -> ConditionCheckerType: + async def async_get_checker(self) -> ConditionChecker: """Evaluate state based on configuration.""" - return lambda hass, vars: False + return lambda **kwargs: False async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]: return {