1
0
mirror of https://github.com/home-assistant/core.git synced 2026-04-02 00:20:30 +01:00

Refactor condition helpers (#165662)

This commit is contained in:
Ariel Ebersberger
2026-03-16 16:57:53 +01:00
committed by GitHub
parent 8bb51c0662
commit ccecbcb389
3 changed files with 41 additions and 56 deletions

View File

@@ -1,11 +1,8 @@
"""Provides conditions for climates."""
from homeassistant.core import HomeAssistant
from homeassistant.helpers.condition import (
Condition,
make_entity_state_attribute_condition,
make_entity_state_condition,
)
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.condition import Condition, make_entity_state_condition
from .const import ATTR_HVAC_ACTION, DOMAIN, HVACAction, HVACMode
@@ -22,14 +19,14 @@ CONDITIONS: dict[str, type[Condition]] = {
HVACMode.HEAT_COOL,
},
),
"is_cooling": make_entity_state_attribute_condition(
DOMAIN, ATTR_HVAC_ACTION, HVACAction.COOLING
"is_cooling": make_entity_state_condition(
{DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.COOLING
),
"is_drying": make_entity_state_attribute_condition(
DOMAIN, ATTR_HVAC_ACTION, HVACAction.DRYING
"is_drying": make_entity_state_condition(
{DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.DRYING
),
"is_heating": make_entity_state_attribute_condition(
DOMAIN, ATTR_HVAC_ACTION, HVACAction.HEATING
"is_heating": make_entity_state_condition(
{DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.HEATING
),
}

View File

@@ -2,22 +2,19 @@
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers.condition import (
Condition,
make_entity_state_attribute_condition,
make_entity_state_condition,
)
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.condition import Condition, make_entity_state_condition
from .const import ATTR_ACTION, DOMAIN, HumidifierAction
CONDITIONS: dict[str, type[Condition]] = {
"is_off": make_entity_state_condition(DOMAIN, STATE_OFF),
"is_on": make_entity_state_condition(DOMAIN, STATE_ON),
"is_drying": make_entity_state_attribute_condition(
DOMAIN, ATTR_ACTION, HumidifierAction.DRYING
"is_drying": make_entity_state_condition(
{DOMAIN: DomainSpec(value_source=ATTR_ACTION)}, HumidifierAction.DRYING
),
"is_humidifying": make_entity_state_attribute_condition(
DOMAIN, ATTR_ACTION, HumidifierAction.HUMIDIFYING
"is_humidifying": make_entity_state_condition(
{DOMAIN: DomainSpec(value_source=ATTR_ACTION)}, HumidifierAction.HUMIDIFYING
),
}

View File

@@ -54,7 +54,7 @@ from homeassistant.const import (
STATE_UNKNOWN,
WEEKDAYS,
)
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.core import HomeAssistant, State, callback, split_entity_id
from homeassistant.exceptions import (
ConditionError,
ConditionErrorContainer,
@@ -361,6 +361,13 @@ class EntityConditionBase[DomainSpecT: DomainSpec = DomainSpec](Condition):
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
def _get_tracked_value(self, entity_state: State) -> Any:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[split_entity_id(entity_state.entity_id)[0]]
if domain_spec.value_source is None:
return entity_state.state
return entity_state.attributes.get(domain_spec.value_source)
@abc.abstractmethod
def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected state(s)."""
@@ -410,13 +417,28 @@ class EntityStateConditionBase(EntityConditionBase):
def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected state(s)."""
return entity_state.state in self._states
return self._get_tracked_value(entity_state) in self._states
def _normalize_domain_specs(
domain_specs: Mapping[str, DomainSpec] | str,
) -> Mapping[str, DomainSpec]:
"""Normalize domain_specs argument to a Mapping."""
if isinstance(domain_specs, str):
return {domain_specs: DomainSpec()}
return domain_specs
def make_entity_state_condition(
domain: str, states: str | set[str]
domain_specs: Mapping[str, DomainSpec] | str,
states: str | set[str],
) -> type[EntityStateConditionBase]:
"""Create a condition for entity state changes to specific state(s)."""
"""Create a condition for entity state changes to specific state(s).
domain_specs can be a string (domain name) for simple state-based conditions,
or a Mapping[str, DomainSpec] for attribute-based or multi-domain conditions.
"""
specs = _normalize_domain_specs(domain_specs)
if isinstance(states, str):
states_set = {states}
@@ -426,43 +448,12 @@ def make_entity_state_condition(
class CustomCondition(EntityStateConditionBase):
"""Condition for entity state."""
_domain_specs = {domain: DomainSpec()}
_domain_specs = specs
_states = states_set
return CustomCondition
class EntityStateAttributeConditionBase(EntityConditionBase):
"""State attribute condition."""
_attribute: str
_attribute_states: set[str]
def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected state(s)."""
return entity_state.attributes.get(self._attribute) in self._attribute_states
def make_entity_state_attribute_condition(
domain: str, attribute: str, attribute_states: str | set[str]
) -> type[EntityStateAttributeConditionBase]:
"""Create a condition for entity attribute matching specific state(s)."""
if isinstance(attribute_states, str):
attribute_states_set = {attribute_states}
else:
attribute_states_set = attribute_states
class CustomCondition(EntityStateAttributeConditionBase):
"""Condition for entity attribute."""
_domain_specs = {domain: DomainSpec()}
_attribute = attribute
_attribute_states = attribute_states_set
return CustomCondition
class ConditionProtocol(Protocol):
"""Define the format of condition modules."""