mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 00:20:30 +01:00
Refactor condition helpers (#165662)
This commit is contained in:
committed by
GitHub
parent
8bb51c0662
commit
ccecbcb389
@@ -1,11 +1,8 @@
|
||||
"""Provides conditions for climates."""
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
make_entity_state_attribute_condition,
|
||||
make_entity_state_condition,
|
||||
)
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.condition import Condition, make_entity_state_condition
|
||||
|
||||
from .const import ATTR_HVAC_ACTION, DOMAIN, HVACAction, HVACMode
|
||||
|
||||
@@ -22,14 +19,14 @@ CONDITIONS: dict[str, type[Condition]] = {
|
||||
HVACMode.HEAT_COOL,
|
||||
},
|
||||
),
|
||||
"is_cooling": make_entity_state_attribute_condition(
|
||||
DOMAIN, ATTR_HVAC_ACTION, HVACAction.COOLING
|
||||
"is_cooling": make_entity_state_condition(
|
||||
{DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.COOLING
|
||||
),
|
||||
"is_drying": make_entity_state_attribute_condition(
|
||||
DOMAIN, ATTR_HVAC_ACTION, HVACAction.DRYING
|
||||
"is_drying": make_entity_state_condition(
|
||||
{DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.DRYING
|
||||
),
|
||||
"is_heating": make_entity_state_attribute_condition(
|
||||
DOMAIN, ATTR_HVAC_ACTION, HVACAction.HEATING
|
||||
"is_heating": make_entity_state_condition(
|
||||
{DOMAIN: DomainSpec(value_source=ATTR_HVAC_ACTION)}, HVACAction.HEATING
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -2,22 +2,19 @@
|
||||
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
make_entity_state_attribute_condition,
|
||||
make_entity_state_condition,
|
||||
)
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.condition import Condition, make_entity_state_condition
|
||||
|
||||
from .const import ATTR_ACTION, DOMAIN, HumidifierAction
|
||||
|
||||
CONDITIONS: dict[str, type[Condition]] = {
|
||||
"is_off": make_entity_state_condition(DOMAIN, STATE_OFF),
|
||||
"is_on": make_entity_state_condition(DOMAIN, STATE_ON),
|
||||
"is_drying": make_entity_state_attribute_condition(
|
||||
DOMAIN, ATTR_ACTION, HumidifierAction.DRYING
|
||||
"is_drying": make_entity_state_condition(
|
||||
{DOMAIN: DomainSpec(value_source=ATTR_ACTION)}, HumidifierAction.DRYING
|
||||
),
|
||||
"is_humidifying": make_entity_state_attribute_condition(
|
||||
DOMAIN, ATTR_ACTION, HumidifierAction.HUMIDIFYING
|
||||
"is_humidifying": make_entity_state_condition(
|
||||
{DOMAIN: DomainSpec(value_source=ATTR_ACTION)}, HumidifierAction.HUMIDIFYING
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ from homeassistant.const import (
|
||||
STATE_UNKNOWN,
|
||||
WEEKDAYS,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, State, callback
|
||||
from homeassistant.core import HomeAssistant, State, callback, split_entity_id
|
||||
from homeassistant.exceptions import (
|
||||
ConditionError,
|
||||
ConditionErrorContainer,
|
||||
@@ -361,6 +361,13 @@ class EntityConditionBase[DomainSpecT: DomainSpec = DomainSpec](Condition):
|
||||
"""Filter entities matching any of the domain specs."""
|
||||
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
|
||||
|
||||
def _get_tracked_value(self, entity_state: State) -> Any:
|
||||
"""Get the tracked value from a state based on the DomainSpec."""
|
||||
domain_spec = self._domain_specs[split_entity_id(entity_state.entity_id)[0]]
|
||||
if domain_spec.value_source is None:
|
||||
return entity_state.state
|
||||
return entity_state.attributes.get(domain_spec.value_source)
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
"""Check if the state matches the expected state(s)."""
|
||||
@@ -410,13 +417,28 @@ class EntityStateConditionBase(EntityConditionBase):
|
||||
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
"""Check if the state matches the expected state(s)."""
|
||||
return entity_state.state in self._states
|
||||
return self._get_tracked_value(entity_state) in self._states
|
||||
|
||||
|
||||
def _normalize_domain_specs(
|
||||
domain_specs: Mapping[str, DomainSpec] | str,
|
||||
) -> Mapping[str, DomainSpec]:
|
||||
"""Normalize domain_specs argument to a Mapping."""
|
||||
if isinstance(domain_specs, str):
|
||||
return {domain_specs: DomainSpec()}
|
||||
return domain_specs
|
||||
|
||||
|
||||
def make_entity_state_condition(
|
||||
domain: str, states: str | set[str]
|
||||
domain_specs: Mapping[str, DomainSpec] | str,
|
||||
states: str | set[str],
|
||||
) -> type[EntityStateConditionBase]:
|
||||
"""Create a condition for entity state changes to specific state(s)."""
|
||||
"""Create a condition for entity state changes to specific state(s).
|
||||
|
||||
domain_specs can be a string (domain name) for simple state-based conditions,
|
||||
or a Mapping[str, DomainSpec] for attribute-based or multi-domain conditions.
|
||||
"""
|
||||
specs = _normalize_domain_specs(domain_specs)
|
||||
|
||||
if isinstance(states, str):
|
||||
states_set = {states}
|
||||
@@ -426,43 +448,12 @@ def make_entity_state_condition(
|
||||
class CustomCondition(EntityStateConditionBase):
|
||||
"""Condition for entity state."""
|
||||
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_domain_specs = specs
|
||||
_states = states_set
|
||||
|
||||
return CustomCondition
|
||||
|
||||
|
||||
class EntityStateAttributeConditionBase(EntityConditionBase):
|
||||
"""State attribute condition."""
|
||||
|
||||
_attribute: str
|
||||
_attribute_states: set[str]
|
||||
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
"""Check if the state matches the expected state(s)."""
|
||||
return entity_state.attributes.get(self._attribute) in self._attribute_states
|
||||
|
||||
|
||||
def make_entity_state_attribute_condition(
|
||||
domain: str, attribute: str, attribute_states: str | set[str]
|
||||
) -> type[EntityStateAttributeConditionBase]:
|
||||
"""Create a condition for entity attribute matching specific state(s)."""
|
||||
|
||||
if isinstance(attribute_states, str):
|
||||
attribute_states_set = {attribute_states}
|
||||
else:
|
||||
attribute_states_set = attribute_states
|
||||
|
||||
class CustomCondition(EntityStateAttributeConditionBase):
|
||||
"""Condition for entity attribute."""
|
||||
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_attribute = attribute
|
||||
_attribute_states = attribute_states_set
|
||||
|
||||
return CustomCondition
|
||||
|
||||
|
||||
class ConditionProtocol(Protocol):
|
||||
"""Define the format of condition modules."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user