From ccecbcb389592f0852e087c3dbfead356c9db359 Mon Sep 17 00:00:00 2001 From: Ariel Ebersberger <31776703+justanotherariel@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:57:53 +0100 Subject: [PATCH] Refactor condition helpers (#165662) --- homeassistant/components/climate/condition.py | 19 +++--- .../components/humidifier/condition.py | 15 ++--- homeassistant/helpers/condition.py | 63 ++++++++----------- 3 files changed, 41 insertions(+), 56 deletions(-) diff --git a/homeassistant/components/climate/condition.py b/homeassistant/components/climate/condition.py index e1cee4ede99..8535890bd5e 100644 --- a/homeassistant/components/climate/condition.py +++ b/homeassistant/components/climate/condition.py @@ -1,11 +1,8 @@ """Provides conditions for climates.""" from homeassistant.core import HomeAssistant -from homeassistant.helpers.condition import ( - Condition, - make_entity_state_attribute_condition, - make_entity_state_condition, -) +from homeassistant.helpers.automation import DomainSpec +from homeassistant.helpers.condition import Condition, make_entity_state_condition from .const import ATTR_HVAC_ACTION, DOMAIN, HVACAction, HVACMode @@ -22,14 +19,14 @@ CONDITIONS: dict[str, type[Condition]] = { HVACMode.HEAT_COOL, }, ), - "is_cooling": make_entity_state_attribute_condition( - DOMAIN, ATTR_HVAC_ACTION, HVACAction.COOLING + "is_cooling": make_entity_state_condition( + {DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.COOLING ), - "is_drying": make_entity_state_attribute_condition( - DOMAIN, ATTR_HVAC_ACTION, HVACAction.DRYING + "is_drying": make_entity_state_condition( + {DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.DRYING ), - "is_heating": make_entity_state_attribute_condition( - DOMAIN, ATTR_HVAC_ACTION, HVACAction.HEATING + "is_heating": make_entity_state_condition( + {DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.HEATING ), } diff --git a/homeassistant/components/humidifier/condition.py b/homeassistant/components/humidifier/condition.py index 77c108128a2..f29100ae402 100644 --- a/homeassistant/components/humidifier/condition.py +++ b/homeassistant/components/humidifier/condition.py @@ -2,22 +2,19 @@ from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant -from homeassistant.helpers.condition import ( - Condition, - make_entity_state_attribute_condition, - make_entity_state_condition, -) +from homeassistant.helpers.automation import DomainSpec +from homeassistant.helpers.condition import Condition, make_entity_state_condition from .const import ATTR_ACTION, DOMAIN, HumidifierAction CONDITIONS: dict[str, type[Condition]] = { "is_off": make_entity_state_condition(DOMAIN, STATE_OFF), "is_on": make_entity_state_condition(DOMAIN, STATE_ON), - "is_drying": make_entity_state_attribute_condition( - DOMAIN, ATTR_ACTION, HumidifierAction.DRYING + "is_drying": make_entity_state_condition( + {DOMAIN: DomainSpec(value_source=ATTR_ACTION)}, HumidifierAction.DRYING ), - "is_humidifying": make_entity_state_attribute_condition( - DOMAIN, ATTR_ACTION, HumidifierAction.HUMIDIFYING + "is_humidifying": make_entity_state_condition( + {DOMAIN: DomainSpec(value_source=ATTR_ACTION)}, HumidifierAction.HUMIDIFYING ), } diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 2c0505aceb2..8e8686b506f 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -54,7 +54,7 @@ from homeassistant.const import ( STATE_UNKNOWN, WEEKDAYS, ) -from homeassistant.core import HomeAssistant, State, callback +from homeassistant.core import HomeAssistant, State, callback, split_entity_id from homeassistant.exceptions import ( ConditionError, ConditionErrorContainer, @@ -361,6 +361,13 @@ class EntityConditionBase[DomainSpecT: DomainSpec = DomainSpec](Condition): """Filter entities matching any of the domain specs.""" return filter_by_domain_specs(self._hass, self._domain_specs, entities) + def _get_tracked_value(self, entity_state: State) -> Any: + """Get the tracked value from a state based on the DomainSpec.""" + domain_spec = self._domain_specs[split_entity_id(entity_state.entity_id)[0]] + if domain_spec.value_source is None: + return entity_state.state + return entity_state.attributes.get(domain_spec.value_source) + @abc.abstractmethod def is_valid_state(self, entity_state: State) -> bool: """Check if the state matches the expected state(s).""" @@ -410,13 +417,28 @@ class EntityStateConditionBase(EntityConditionBase): def is_valid_state(self, entity_state: State) -> bool: """Check if the state matches the expected state(s).""" - return entity_state.state in self._states + return self._get_tracked_value(entity_state) in self._states + + +def _normalize_domain_specs( + domain_specs: Mapping[str, DomainSpec] | str, +) -> Mapping[str, DomainSpec]: + """Normalize domain_specs argument to a Mapping.""" + if isinstance(domain_specs, str): + return {domain_specs: DomainSpec()} + return domain_specs def make_entity_state_condition( - domain: str, states: str | set[str] + domain_specs: Mapping[str, DomainSpec] | str, + states: str | set[str], ) -> type[EntityStateConditionBase]: - """Create a condition for entity state changes to specific state(s).""" + """Create a condition for entity state changes to specific state(s). + + domain_specs can be a string (domain name) for simple state-based conditions, + or a Mapping[str, DomainSpec] for attribute-based or multi-domain conditions. + """ + specs = _normalize_domain_specs(domain_specs) if isinstance(states, str): states_set = {states} @@ -426,43 +448,12 @@ def make_entity_state_condition( class CustomCondition(EntityStateConditionBase): """Condition for entity state.""" - _domain_specs = {domain: DomainSpec()} + _domain_specs = specs _states = states_set return CustomCondition -class EntityStateAttributeConditionBase(EntityConditionBase): - """State attribute condition.""" - - _attribute: str - _attribute_states: set[str] - - def is_valid_state(self, entity_state: State) -> bool: - """Check if the state matches the expected state(s).""" - return entity_state.attributes.get(self._attribute) in self._attribute_states - - -def make_entity_state_attribute_condition( - domain: str, attribute: str, attribute_states: str | set[str] -) -> type[EntityStateAttributeConditionBase]: - """Create a condition for entity attribute matching specific state(s).""" - - if isinstance(attribute_states, str): - attribute_states_set = {attribute_states} - else: - attribute_states_set = attribute_states - - class CustomCondition(EntityStateAttributeConditionBase): - """Condition for entity attribute.""" - - _domain_specs = {domain: DomainSpec()} - _attribute = attribute - _attribute_states = attribute_states_set - - return CustomCondition - - class ConditionProtocol(Protocol): """Define the format of condition modules."""