mirror of
https://github.com/home-assistant/core.git
synced 2026-04-17 23:53:49 +01:00
Add DomainSpec to trigger and condition helpers (#165392)
This commit is contained in:
committed by
GitHub
parent
d96191723f
commit
eb17367229
@@ -2,6 +2,7 @@
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
EntityStateConditionBase,
|
||||
@@ -43,7 +44,7 @@ def make_entity_state_required_features_condition(
|
||||
class CustomCondition(EntityStateRequiredFeaturesCondition):
|
||||
"""Condition for entity state changes."""
|
||||
|
||||
_domain = domain
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_states = {to_state}
|
||||
_required_features = required_features
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.entity import get_supported_features
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityTargetStateTriggerBase,
|
||||
@@ -44,7 +45,7 @@ def make_entity_state_trigger_required_features(
|
||||
class CustomTrigger(EntityStateTriggerRequiredFeatures):
|
||||
"""Trigger for entity state changes."""
|
||||
|
||||
_domains = {domain}
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_to_states = {to_state}
|
||||
_required_features = required_features
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
ENTITY_STATE_TRIGGER_SCHEMA,
|
||||
EntityTriggerBase,
|
||||
@@ -14,7 +15,7 @@ from . import DOMAIN
|
||||
class ButtonPressedTrigger(EntityTriggerBase):
|
||||
"""Trigger for button entity presses."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_domain_specs = {DOMAIN: DomainSpec()}
|
||||
_schema = ENTITY_STATE_TRIGGER_SCHEMA
|
||||
|
||||
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
|
||||
|
||||
@@ -5,13 +5,14 @@ import voluptuous as vol
|
||||
from homeassistant.const import ATTR_TEMPERATURE, CONF_OPTIONS
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.automation import DomainSpec, NumericalDomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST,
|
||||
EntityTargetStateTriggerBase,
|
||||
Trigger,
|
||||
TriggerConfig,
|
||||
make_entity_numerical_state_attribute_changed_trigger,
|
||||
make_entity_numerical_state_attribute_crossed_threshold_trigger,
|
||||
make_entity_numerical_state_changed_trigger,
|
||||
make_entity_numerical_state_crossed_threshold_trigger,
|
||||
make_entity_target_state_attribute_trigger,
|
||||
make_entity_target_state_trigger,
|
||||
make_entity_transition_trigger,
|
||||
@@ -35,7 +36,7 @@ HVAC_MODE_CHANGED_TRIGGER_SCHEMA = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST.extend
|
||||
class HVACModeChangedTrigger(EntityTargetStateTriggerBase):
|
||||
"""Trigger for entity state changes."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_domain_specs = {DOMAIN: DomainSpec()}
|
||||
_schema = HVAC_MODE_CHANGED_TRIGGER_SCHEMA
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
|
||||
@@ -52,17 +53,17 @@ TRIGGERS: dict[str, type[Trigger]] = {
|
||||
"started_drying": make_entity_target_state_attribute_trigger(
|
||||
DOMAIN, ATTR_HVAC_ACTION, HVACAction.DRYING
|
||||
),
|
||||
"target_humidity_changed": make_entity_numerical_state_attribute_changed_trigger(
|
||||
{DOMAIN}, {DOMAIN: ATTR_HUMIDITY}
|
||||
"target_humidity_changed": make_entity_numerical_state_changed_trigger(
|
||||
{DOMAIN: NumericalDomainSpec(value_source=ATTR_HUMIDITY)}
|
||||
),
|
||||
"target_humidity_crossed_threshold": make_entity_numerical_state_attribute_crossed_threshold_trigger(
|
||||
{DOMAIN}, {DOMAIN: ATTR_HUMIDITY}
|
||||
"target_humidity_crossed_threshold": make_entity_numerical_state_crossed_threshold_trigger(
|
||||
{DOMAIN: NumericalDomainSpec(value_source=ATTR_HUMIDITY)}
|
||||
),
|
||||
"target_temperature_changed": make_entity_numerical_state_attribute_changed_trigger(
|
||||
{DOMAIN}, {DOMAIN: ATTR_TEMPERATURE}
|
||||
"target_temperature_changed": make_entity_numerical_state_changed_trigger(
|
||||
{DOMAIN: NumericalDomainSpec(value_source=ATTR_TEMPERATURE)}
|
||||
),
|
||||
"target_temperature_crossed_threshold": make_entity_numerical_state_attribute_crossed_threshold_trigger(
|
||||
{DOMAIN}, {DOMAIN: ATTR_TEMPERATURE}
|
||||
"target_temperature_crossed_threshold": make_entity_numerical_state_crossed_threshold_trigger(
|
||||
{DOMAIN: NumericalDomainSpec(value_source=ATTR_TEMPERATURE)}
|
||||
),
|
||||
"turned_off": make_entity_target_state_trigger(DOMAIN, HVACMode.OFF),
|
||||
"turned_on": make_entity_transition_trigger(
|
||||
|
||||
@@ -1,81 +1,82 @@
|
||||
"""Provides triggers for covers."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, State, split_entity_id
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityTriggerBase,
|
||||
Trigger,
|
||||
get_device_class_or_undefined,
|
||||
)
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import EntityTriggerBase, Trigger
|
||||
|
||||
from .const import ATTR_IS_CLOSED, DOMAIN, CoverDeviceClass
|
||||
|
||||
|
||||
class CoverTriggerBase(EntityTriggerBase):
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CoverDomainSpec(DomainSpec):
|
||||
"""DomainSpec with a target value for comparison."""
|
||||
|
||||
target_value: str | bool | None = None
|
||||
|
||||
|
||||
class CoverTriggerBase(EntityTriggerBase[CoverDomainSpec]):
|
||||
"""Base trigger for cover state changes."""
|
||||
|
||||
_binary_sensor_target_state: str
|
||||
_cover_is_closed_target_value: bool
|
||||
_device_classes: dict[str, str]
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities by cover device class."""
|
||||
entities = super().entity_filter(entities)
|
||||
return {
|
||||
entity_id
|
||||
for entity_id in entities
|
||||
if get_device_class_or_undefined(self._hass, entity_id)
|
||||
== self._device_classes[split_entity_id(entity_id)[0]]
|
||||
}
|
||||
def _get_value(self, state: State) -> str | bool | None:
|
||||
"""Extract the relevant value from state based on domain spec."""
|
||||
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
|
||||
if domain_spec.value_source is not None:
|
||||
return state.attributes.get(domain_spec.value_source)
|
||||
return state.state
|
||||
|
||||
def is_valid_state(self, state: State) -> bool:
|
||||
"""Check if the state matches the target cover state."""
|
||||
if split_entity_id(state.entity_id)[0] == DOMAIN:
|
||||
return (
|
||||
state.attributes.get(ATTR_IS_CLOSED)
|
||||
== self._cover_is_closed_target_value
|
||||
)
|
||||
return state.state == self._binary_sensor_target_state
|
||||
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
|
||||
return self._get_value(state) == domain_spec.target_value
|
||||
|
||||
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
|
||||
"""Check if the transition is valid for a cover state change."""
|
||||
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
|
||||
return False
|
||||
if split_entity_id(from_state.entity_id)[0] == DOMAIN:
|
||||
if (from_is_closed := from_state.attributes.get(ATTR_IS_CLOSED)) is None:
|
||||
return False
|
||||
return from_is_closed != to_state.attributes.get(ATTR_IS_CLOSED) # type: ignore[no-any-return]
|
||||
return from_state.state != to_state.state
|
||||
if (from_value := self._get_value(from_state)) is None:
|
||||
return False
|
||||
return from_value != self._get_value(to_state)
|
||||
|
||||
|
||||
def make_cover_opened_trigger(
|
||||
*, device_classes: dict[str, str], domains: set[str] | None = None
|
||||
*, device_classes: dict[str, str]
|
||||
) -> type[CoverTriggerBase]:
|
||||
"""Create a trigger cover_opened."""
|
||||
|
||||
class CoverOpenedTrigger(CoverTriggerBase):
|
||||
"""Trigger for cover opened state changes."""
|
||||
|
||||
_binary_sensor_target_state = STATE_ON
|
||||
_cover_is_closed_target_value = False
|
||||
_domains = domains or {DOMAIN}
|
||||
_device_classes = device_classes
|
||||
_domain_specs = {
|
||||
domain: CoverDomainSpec(
|
||||
device_class=dc,
|
||||
value_source=ATTR_IS_CLOSED if domain == DOMAIN else None,
|
||||
target_value=False if domain == DOMAIN else STATE_ON,
|
||||
)
|
||||
for domain, dc in device_classes.items()
|
||||
}
|
||||
|
||||
return CoverOpenedTrigger
|
||||
|
||||
|
||||
def make_cover_closed_trigger(
|
||||
*, device_classes: dict[str, str], domains: set[str] | None = None
|
||||
*, device_classes: dict[str, str]
|
||||
) -> type[CoverTriggerBase]:
|
||||
"""Create a trigger cover_closed."""
|
||||
|
||||
class CoverClosedTrigger(CoverTriggerBase):
|
||||
"""Trigger for cover closed state changes."""
|
||||
|
||||
_binary_sensor_target_state = STATE_OFF
|
||||
_cover_is_closed_target_value = True
|
||||
_domains = domains or {DOMAIN}
|
||||
_device_classes = device_classes
|
||||
_domain_specs = {
|
||||
domain: CoverDomainSpec(
|
||||
device_class=dc,
|
||||
value_source=ATTR_IS_CLOSED if domain == DOMAIN else None,
|
||||
target_value=True if domain == DOMAIN else STATE_OFF,
|
||||
)
|
||||
for domain, dc in device_classes.items()
|
||||
}
|
||||
|
||||
return CoverClosedTrigger
|
||||
|
||||
|
||||
@@ -20,14 +20,8 @@ DEVICE_CLASSES_DOOR: dict[str, str] = {
|
||||
|
||||
|
||||
TRIGGERS: dict[str, type[Trigger]] = {
|
||||
"opened": make_cover_opened_trigger(
|
||||
device_classes=DEVICE_CLASSES_DOOR,
|
||||
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
|
||||
),
|
||||
"closed": make_cover_closed_trigger(
|
||||
device_classes=DEVICE_CLASSES_DOOR,
|
||||
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
|
||||
),
|
||||
"opened": make_cover_opened_trigger(device_classes=DEVICE_CLASSES_DOOR),
|
||||
"closed": make_cover_closed_trigger(device_classes=DEVICE_CLASSES_DOOR),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -20,14 +20,8 @@ DEVICE_CLASSES_GARAGE_DOOR: dict[str, str] = {
|
||||
|
||||
|
||||
TRIGGERS: dict[str, type[Trigger]] = {
|
||||
"opened": make_cover_opened_trigger(
|
||||
device_classes=DEVICE_CLASSES_GARAGE_DOOR,
|
||||
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
|
||||
),
|
||||
"closed": make_cover_closed_trigger(
|
||||
device_classes=DEVICE_CLASSES_GARAGE_DOOR,
|
||||
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
|
||||
),
|
||||
"opened": make_cover_opened_trigger(device_classes=DEVICE_CLASSES_GARAGE_DOOR),
|
||||
"closed": make_cover_closed_trigger(device_classes=DEVICE_CLASSES_GARAGE_DOOR),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -15,50 +15,43 @@ from homeassistant.components.weather import (
|
||||
ATTR_WEATHER_HUMIDITY,
|
||||
DOMAIN as WEATHER_DOMAIN,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, split_entity_id
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.automation import NumericalDomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityNumericalStateAttributeChangedTriggerBase,
|
||||
EntityNumericalStateAttributeCrossedThresholdTriggerBase,
|
||||
EntityTriggerBase,
|
||||
Trigger,
|
||||
get_device_class_or_undefined,
|
||||
)
|
||||
|
||||
|
||||
class _HumidityTriggerMixin(EntityTriggerBase):
|
||||
"""Mixin for humidity triggers providing entity filtering and value extraction."""
|
||||
|
||||
_attributes = {
|
||||
CLIMATE_DOMAIN: CLIMATE_ATTR_CURRENT_HUMIDITY,
|
||||
HUMIDIFIER_DOMAIN: HUMIDIFIER_ATTR_CURRENT_HUMIDITY,
|
||||
SENSOR_DOMAIN: None, # Use state.state
|
||||
WEATHER_DOMAIN: ATTR_WEATHER_HUMIDITY,
|
||||
}
|
||||
_domains = {SENSOR_DOMAIN, CLIMATE_DOMAIN, HUMIDIFIER_DOMAIN, WEATHER_DOMAIN}
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities: all climate/humidifier/weather, sensor only with device_class humidity."""
|
||||
entities = super().entity_filter(entities)
|
||||
return {
|
||||
entity_id
|
||||
for entity_id in entities
|
||||
if split_entity_id(entity_id)[0] != SENSOR_DOMAIN
|
||||
or get_device_class_or_undefined(self._hass, entity_id)
|
||||
== SensorDeviceClass.HUMIDITY
|
||||
}
|
||||
HUMIDITY_DOMAIN_SPECS: dict[str, NumericalDomainSpec] = {
|
||||
CLIMATE_DOMAIN: NumericalDomainSpec(
|
||||
value_source=CLIMATE_ATTR_CURRENT_HUMIDITY,
|
||||
),
|
||||
HUMIDIFIER_DOMAIN: NumericalDomainSpec(
|
||||
value_source=HUMIDIFIER_ATTR_CURRENT_HUMIDITY,
|
||||
),
|
||||
SENSOR_DOMAIN: NumericalDomainSpec(
|
||||
device_class=SensorDeviceClass.HUMIDITY,
|
||||
),
|
||||
WEATHER_DOMAIN: NumericalDomainSpec(
|
||||
value_source=ATTR_WEATHER_HUMIDITY,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class HumidityChangedTrigger(
|
||||
_HumidityTriggerMixin, EntityNumericalStateAttributeChangedTriggerBase
|
||||
):
|
||||
class HumidityChangedTrigger(EntityNumericalStateAttributeChangedTriggerBase):
|
||||
"""Trigger for humidity value changes across multiple domains."""
|
||||
|
||||
_domain_specs = HUMIDITY_DOMAIN_SPECS
|
||||
|
||||
|
||||
class HumidityCrossedThresholdTrigger(
|
||||
_HumidityTriggerMixin, EntityNumericalStateAttributeCrossedThresholdTriggerBase
|
||||
EntityNumericalStateAttributeCrossedThresholdTriggerBase
|
||||
):
|
||||
"""Trigger for humidity value crossing a threshold across multiple domains."""
|
||||
|
||||
_domain_specs = HUMIDITY_DOMAIN_SPECS
|
||||
|
||||
|
||||
TRIGGERS: dict[str, type[Trigger]] = {
|
||||
"changed": HumidityChangedTrigger,
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.automation import NumericalDomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityNumericalStateAttributeChangedTriggerBase,
|
||||
EntityNumericalStateAttributeCrossedThresholdTriggerBase,
|
||||
@@ -20,13 +21,18 @@ def _convert_uint8_to_percentage(value: Any) -> float:
|
||||
return (float(value) / 255.0) * 100.0
|
||||
|
||||
|
||||
BRIGHTNESS_DOMAIN_SPECS = {
|
||||
DOMAIN: NumericalDomainSpec(
|
||||
value_source=ATTR_BRIGHTNESS,
|
||||
value_converter=_convert_uint8_to_percentage,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class BrightnessChangedTrigger(EntityNumericalStateAttributeChangedTriggerBase):
|
||||
"""Trigger for brightness changed."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_attributes = {DOMAIN: ATTR_BRIGHTNESS}
|
||||
|
||||
_converter = staticmethod(_convert_uint8_to_percentage)
|
||||
_domain_specs = BRIGHTNESS_DOMAIN_SPECS
|
||||
|
||||
|
||||
class BrightnessCrossedThresholdTrigger(
|
||||
@@ -34,9 +40,7 @@ class BrightnessCrossedThresholdTrigger(
|
||||
):
|
||||
"""Trigger for brightness crossed threshold."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_attributes = {DOMAIN: ATTR_BRIGHTNESS}
|
||||
_converter = staticmethod(_convert_uint8_to_percentage)
|
||||
_domain_specs = BRIGHTNESS_DOMAIN_SPECS
|
||||
|
||||
|
||||
TRIGGERS: dict[str, type[Trigger]] = {
|
||||
|
||||
@@ -6,28 +6,20 @@ from homeassistant.components.binary_sensor import (
|
||||
)
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityTargetStateTriggerBase,
|
||||
EntityTriggerBase,
|
||||
Trigger,
|
||||
get_device_class_or_undefined,
|
||||
)
|
||||
|
||||
|
||||
class _MotionBinaryTriggerBase(EntityTriggerBase):
|
||||
"""Base trigger for motion binary sensor state changes."""
|
||||
|
||||
_domains = {BINARY_SENSOR_DOMAIN}
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities by motion device class."""
|
||||
entities = super().entity_filter(entities)
|
||||
return {
|
||||
entity_id
|
||||
for entity_id in entities
|
||||
if get_device_class_or_undefined(self._hass, entity_id)
|
||||
== BinarySensorDeviceClass.MOTION
|
||||
}
|
||||
_domain_specs = {
|
||||
BINARY_SENSOR_DOMAIN: DomainSpec(device_class=BinarySensorDeviceClass.MOTION)
|
||||
}
|
||||
|
||||
|
||||
class MotionDetectedTrigger(_MotionBinaryTriggerBase, EntityTargetStateTriggerBase):
|
||||
|
||||
@@ -6,28 +6,20 @@ from homeassistant.components.binary_sensor import (
|
||||
)
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityTargetStateTriggerBase,
|
||||
EntityTriggerBase,
|
||||
Trigger,
|
||||
get_device_class_or_undefined,
|
||||
)
|
||||
|
||||
|
||||
class _OccupancyBinaryTriggerBase(EntityTriggerBase):
|
||||
"""Base trigger for occupancy binary sensor state changes."""
|
||||
|
||||
_domains = {BINARY_SENSOR_DOMAIN}
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities by occupancy device class."""
|
||||
entities = super().entity_filter(entities)
|
||||
return {
|
||||
entity_id
|
||||
for entity_id in entities
|
||||
if get_device_class_or_undefined(self._hass, entity_id)
|
||||
== BinarySensorDeviceClass.OCCUPANCY
|
||||
}
|
||||
_domain_specs = {
|
||||
BINARY_SENSOR_DOMAIN: DomainSpec(device_class=BinarySensorDeviceClass.OCCUPANCY)
|
||||
}
|
||||
|
||||
|
||||
class OccupancyDetectedTrigger(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
ENTITY_STATE_TRIGGER_SCHEMA,
|
||||
EntityTriggerBase,
|
||||
@@ -14,7 +15,7 @@ from . import DOMAIN
|
||||
class SceneActivatedTrigger(EntityTriggerBase):
|
||||
"""Trigger for scene entity activations."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_domain_specs = {DOMAIN: DomainSpec()}
|
||||
_schema = ENTITY_STATE_TRIGGER_SCHEMA
|
||||
|
||||
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
EntityTransitionTriggerBase,
|
||||
Trigger,
|
||||
@@ -14,7 +15,7 @@ from .const import ATTR_NEXT_EVENT, DOMAIN
|
||||
class ScheduleBackToBackTrigger(EntityTransitionTriggerBase):
|
||||
"""Trigger for back-to-back schedule blocks."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_domain_specs = {DOMAIN: DomainSpec()}
|
||||
_from_states = {STATE_OFF, STATE_ON}
|
||||
_to_states = {STATE_ON}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.automation import DomainSpec
|
||||
from homeassistant.helpers.trigger import (
|
||||
ENTITY_STATE_TRIGGER_SCHEMA,
|
||||
EntityTriggerBase,
|
||||
@@ -14,7 +15,7 @@ from .const import DOMAIN
|
||||
class TextChangedTrigger(EntityTriggerBase):
|
||||
"""Trigger for text entity when its content changes."""
|
||||
|
||||
_domains = {DOMAIN}
|
||||
_domain_specs = {DOMAIN: DomainSpec()}
|
||||
_schema = ENTITY_STATE_TRIGGER_SCHEMA
|
||||
|
||||
def is_valid_state(self, state: State) -> bool:
|
||||
|
||||
@@ -20,14 +20,8 @@ DEVICE_CLASSES_WINDOW: dict[str, str] = {
|
||||
|
||||
|
||||
TRIGGERS: dict[str, type[Trigger]] = {
|
||||
"opened": make_cover_opened_trigger(
|
||||
device_classes=DEVICE_CLASSES_WINDOW,
|
||||
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
|
||||
),
|
||||
"closed": make_cover_closed_trigger(
|
||||
device_classes=DEVICE_CLASSES_WINDOW,
|
||||
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
|
||||
),
|
||||
"opened": make_cover_opened_trigger(device_classes=DEVICE_CLASSES_WINDOW),
|
||||
"closed": make_cover_closed_trigger(device_classes=DEVICE_CLASSES_WINDOW),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,68 @@
|
||||
"""Helpers for automation."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_OPTIONS
|
||||
from homeassistant.core import HomeAssistant, split_entity_id
|
||||
|
||||
from .entity import get_device_class_or_undefined
|
||||
from .typing import ConfigType
|
||||
|
||||
|
||||
class AnyDeviceClassType(Enum):
|
||||
"""Singleton type for matching any device class."""
|
||||
|
||||
_singleton = 0
|
||||
|
||||
|
||||
ANY_DEVICE_CLASS = AnyDeviceClassType._singleton # noqa: SLF001
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DomainSpec:
|
||||
"""Describes how to match and extract a value from an entity.
|
||||
|
||||
Used by triggers and conditions.
|
||||
"""
|
||||
|
||||
device_class: str | None | AnyDeviceClassType = ANY_DEVICE_CLASS
|
||||
value_source: str | None = None
|
||||
"""Attribute name to extract the value from, or None for state.state."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class NumericalDomainSpec(DomainSpec):
|
||||
"""DomainSpec with an optional value converter for numerical triggers."""
|
||||
|
||||
value_converter: Callable[[Any], float] | None = None
|
||||
"""Optional converter for numerical values (e.g. uint8 → percentage)."""
|
||||
|
||||
|
||||
def filter_by_domain_specs(
|
||||
hass: HomeAssistant,
|
||||
domain_specs: Mapping[str, DomainSpec],
|
||||
entities: set[str],
|
||||
) -> set[str]:
|
||||
"""Filter entities matching any of the domain specs."""
|
||||
result: set[str] = set()
|
||||
for entity_id in entities:
|
||||
if not (domain_spec := domain_specs.get(split_entity_id(entity_id)[0])):
|
||||
continue
|
||||
if (
|
||||
domain_spec.device_class is not ANY_DEVICE_CLASS
|
||||
and get_device_class_or_undefined(hass, entity_id)
|
||||
!= domain_spec.device_class
|
||||
):
|
||||
continue
|
||||
result.add(entity_id)
|
||||
return result
|
||||
|
||||
|
||||
def get_absolute_description_key(domain: str, key: str) -> str:
|
||||
"""Return the absolute description key."""
|
||||
if not key.startswith("_"):
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Container, Coroutine, Generator, Iterable
|
||||
from collections.abc import Callable, Container, Coroutine, Generator, Iterable, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, time as dt_time, timedelta
|
||||
@@ -54,7 +54,7 @@ from homeassistant.const import (
|
||||
STATE_UNKNOWN,
|
||||
WEEKDAYS,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, State, callback, split_entity_id
|
||||
from homeassistant.core import HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import (
|
||||
ConditionError,
|
||||
ConditionErrorContainer,
|
||||
@@ -76,6 +76,8 @@ from homeassistant.util.yaml import load_yaml_dict
|
||||
|
||||
from . import config_validation as cv, entity_registry as er, selector
|
||||
from .automation import (
|
||||
DomainSpec,
|
||||
filter_by_domain_specs,
|
||||
get_absolute_description_key,
|
||||
get_relative_description_key,
|
||||
move_options_fields_to_top_level,
|
||||
@@ -332,10 +334,10 @@ ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL = vol.Schema(
|
||||
)
|
||||
|
||||
|
||||
class EntityConditionBase(Condition):
|
||||
class EntityConditionBase[DomainSpecT: DomainSpec = DomainSpec](Condition):
|
||||
"""Base class for entity conditions."""
|
||||
|
||||
_domain: str
|
||||
_domain_specs: Mapping[str, DomainSpecT]
|
||||
_schema: vol.Schema = ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL
|
||||
|
||||
@override
|
||||
@@ -356,12 +358,8 @@ class EntityConditionBase(Condition):
|
||||
self._behavior = config.options[ATTR_BEHAVIOR]
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities of this domain."""
|
||||
return {
|
||||
entity_id
|
||||
for entity_id in entities
|
||||
if split_entity_id(entity_id)[0] == self._domain
|
||||
}
|
||||
"""Filter entities matching any of the domain specs."""
|
||||
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
@@ -428,7 +426,7 @@ def make_entity_state_condition(
|
||||
class CustomCondition(EntityStateConditionBase):
|
||||
"""Condition for entity state."""
|
||||
|
||||
_domain = domain
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_states = states_set
|
||||
|
||||
return CustomCondition
|
||||
@@ -458,7 +456,7 @@ def make_entity_state_attribute_condition(
|
||||
class CustomCondition(EntityStateAttributeConditionBase):
|
||||
"""Condition for entity attribute."""
|
||||
|
||||
_domain = domain
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_attribute = attribute
|
||||
_attribute_states = attribute_states_set
|
||||
|
||||
|
||||
@@ -169,6 +169,16 @@ def get_device_class(hass: HomeAssistant, entity_id: str) -> str | None:
|
||||
return entry.device_class or entry.original_device_class
|
||||
|
||||
|
||||
def get_device_class_or_undefined(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> str | None | UndefinedType:
|
||||
"""Get the device class of an entity or UNDEFINED if not found."""
|
||||
try:
|
||||
return get_device_class(hass, entity_id)
|
||||
except HomeAssistantError:
|
||||
return UNDEFINED
|
||||
|
||||
|
||||
def get_supported_features(hass: HomeAssistant, entity_id: str) -> int:
|
||||
"""Get supported features for an entity.
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from collections.abc import Callable, Coroutine, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
import functools
|
||||
@@ -69,11 +69,13 @@ from homeassistant.util.yaml import load_yaml_dict
|
||||
|
||||
from . import config_validation as cv, selector
|
||||
from .automation import (
|
||||
DomainSpec,
|
||||
NumericalDomainSpec,
|
||||
filter_by_domain_specs,
|
||||
get_absolute_description_key,
|
||||
get_relative_description_key,
|
||||
move_options_fields_to_top_level,
|
||||
)
|
||||
from .entity import get_device_class
|
||||
from .integration_platform import async_process_integration_platforms
|
||||
from .selector import TargetSelector
|
||||
from .target import (
|
||||
@@ -81,7 +83,7 @@ from .target import (
|
||||
async_track_target_selector_state_change_event,
|
||||
)
|
||||
from .template import Template
|
||||
from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType
|
||||
from .typing import ConfigType, TemplateVarsType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -334,20 +336,10 @@ ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST = ENTITY_STATE_TRIGGER_SCHEMA.extend(
|
||||
)
|
||||
|
||||
|
||||
def get_device_class_or_undefined(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> str | None | UndefinedType:
|
||||
"""Get the device class of an entity or UNDEFINED if not found."""
|
||||
try:
|
||||
return get_device_class(hass, entity_id)
|
||||
except HomeAssistantError:
|
||||
return UNDEFINED
|
||||
|
||||
|
||||
class EntityTriggerBase(Trigger):
|
||||
class EntityTriggerBase[DomainSpecT: DomainSpec = DomainSpec](Trigger):
|
||||
"""Trigger for entity state changes."""
|
||||
|
||||
_domains: set[str]
|
||||
_domain_specs: Mapping[str, DomainSpecT]
|
||||
_schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST
|
||||
|
||||
@override
|
||||
@@ -366,6 +358,10 @@ class EntityTriggerBase(Trigger):
|
||||
self._options = config.options or {}
|
||||
self._target = config.target
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities matching any of the domain specs."""
|
||||
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
|
||||
|
||||
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
|
||||
"""Check if the origin state is valid and the state has changed."""
|
||||
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
|
||||
@@ -396,14 +392,6 @@ class EntityTriggerBase(Trigger):
|
||||
== 1
|
||||
)
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities of these domains."""
|
||||
return {
|
||||
entity_id
|
||||
for entity_id in entities
|
||||
if split_entity_id(entity_id)[0] in self._domains
|
||||
}
|
||||
|
||||
@override
|
||||
async def async_attach_runner(
|
||||
self, run_action: TriggerActionRunner
|
||||
@@ -611,19 +599,22 @@ def _get_numerical_value(
|
||||
return entity_or_float
|
||||
|
||||
|
||||
class EntityNumericalStateBase(EntityTriggerBase):
|
||||
class EntityNumericalStateBase(EntityTriggerBase[NumericalDomainSpec]):
|
||||
"""Base class for numerical state and state attribute triggers."""
|
||||
|
||||
_attributes: dict[str, str | None]
|
||||
_converter: Callable[[Any], float] = float
|
||||
|
||||
def _get_tracked_value(self, state: State) -> Any:
|
||||
"""Get the tracked numerical value from a state."""
|
||||
domain = split_entity_id(state.entity_id)[0]
|
||||
source = self._attributes[domain]
|
||||
if source is None:
|
||||
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
|
||||
if domain_spec.value_source is None:
|
||||
return state.state
|
||||
return state.attributes.get(source)
|
||||
return state.attributes.get(domain_spec.value_source)
|
||||
|
||||
def _get_converter(self, state: State) -> Callable[[Any], float]:
|
||||
"""Get the value converter for an entity."""
|
||||
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
|
||||
if domain_spec.value_converter is not None:
|
||||
return domain_spec.value_converter
|
||||
return float
|
||||
|
||||
|
||||
class EntityNumericalStateAttributeChangedTriggerBase(EntityNumericalStateBase):
|
||||
@@ -654,7 +645,7 @@ class EntityNumericalStateAttributeChangedTriggerBase(EntityNumericalStateBase):
|
||||
return False
|
||||
|
||||
try:
|
||||
current_value = self._converter(_attribute_value)
|
||||
current_value = self._get_converter(state)(_attribute_value)
|
||||
except TypeError, ValueError:
|
||||
# Value is not a valid number, don't trigger
|
||||
return False
|
||||
@@ -780,7 +771,7 @@ class EntityNumericalStateAttributeCrossedThresholdTriggerBase(
|
||||
return False
|
||||
|
||||
try:
|
||||
current_value = self._converter(_attribute_value)
|
||||
current_value = self._get_converter(state)(_attribute_value)
|
||||
except TypeError, ValueError:
|
||||
# Value is not a valid number, don't trigger
|
||||
return False
|
||||
@@ -812,7 +803,7 @@ def make_entity_target_state_trigger(
|
||||
class CustomTrigger(EntityTargetStateTriggerBase):
|
||||
"""Trigger for entity state changes."""
|
||||
|
||||
_domains = {domain}
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_to_states = to_states_set
|
||||
|
||||
return CustomTrigger
|
||||
@@ -826,7 +817,7 @@ def make_entity_transition_trigger(
|
||||
class CustomTrigger(EntityTransitionTriggerBase):
|
||||
"""Trigger for conditional entity state changes."""
|
||||
|
||||
_domains = {domain}
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_from_states = from_states
|
||||
_to_states = to_states
|
||||
|
||||
@@ -841,36 +832,34 @@ def make_entity_origin_state_trigger(
|
||||
class CustomTrigger(EntityOriginStateTriggerBase):
|
||||
"""Trigger for entity "from state" changes."""
|
||||
|
||||
_domains = {domain}
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_from_state = from_state
|
||||
|
||||
return CustomTrigger
|
||||
|
||||
|
||||
def make_entity_numerical_state_attribute_changed_trigger(
|
||||
domains: set[str], attributes: dict[str, str | None]
|
||||
def make_entity_numerical_state_changed_trigger(
|
||||
domain_specs: Mapping[str, NumericalDomainSpec],
|
||||
) -> type[EntityNumericalStateAttributeChangedTriggerBase]:
|
||||
"""Create a trigger for numerical state attribute change."""
|
||||
"""Create a trigger for numerical state value change."""
|
||||
|
||||
class CustomTrigger(EntityNumericalStateAttributeChangedTriggerBase):
|
||||
"""Trigger for numerical state attribute changes."""
|
||||
"""Trigger for numerical state value changes."""
|
||||
|
||||
_domains = domains
|
||||
_attributes = attributes
|
||||
_domain_specs = domain_specs
|
||||
|
||||
return CustomTrigger
|
||||
|
||||
|
||||
def make_entity_numerical_state_attribute_crossed_threshold_trigger(
|
||||
domains: set[str], attributes: dict[str, str | None]
|
||||
def make_entity_numerical_state_crossed_threshold_trigger(
|
||||
domain_specs: Mapping[str, NumericalDomainSpec],
|
||||
) -> type[EntityNumericalStateAttributeCrossedThresholdTriggerBase]:
|
||||
"""Create a trigger for numerical state attribute change."""
|
||||
"""Create a trigger for numerical state value crossing a threshold."""
|
||||
|
||||
class CustomTrigger(EntityNumericalStateAttributeCrossedThresholdTriggerBase):
|
||||
"""Trigger for numerical state attribute changes."""
|
||||
"""Trigger for numerical state value crossing a threshold."""
|
||||
|
||||
_domains = domains
|
||||
_attributes = attributes
|
||||
_domain_specs = domain_specs
|
||||
|
||||
return CustomTrigger
|
||||
|
||||
@@ -883,7 +872,7 @@ def make_entity_target_state_attribute_trigger(
|
||||
class CustomTrigger(EntityTargetStateAttributeTriggerBase):
|
||||
"""Trigger for entity state changes."""
|
||||
|
||||
_domains = {domain}
|
||||
_domain_specs = {domain: DomainSpec()}
|
||||
_attribute = attribute
|
||||
_attribute_to_state = to_state
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""The tests for the trigger helper."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from contextlib import AbstractContextManager, nullcontext as does_not_raise
|
||||
import io
|
||||
from typing import Any
|
||||
@@ -15,6 +16,7 @@ from homeassistant.components.system_health import DOMAIN as SYSTEM_HEALTH_DOMAI
|
||||
from homeassistant.components.tag import DOMAIN as TAG_DOMAIN
|
||||
from homeassistant.components.text import DOMAIN as TEXT_DOMAIN
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
CONF_ABOVE,
|
||||
CONF_BELOW,
|
||||
CONF_ENTITY_ID,
|
||||
@@ -33,20 +35,27 @@ from homeassistant.core import (
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, trigger
|
||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||
from homeassistant.helpers.automation import (
|
||||
ANY_DEVICE_CLASS,
|
||||
DomainSpec,
|
||||
NumericalDomainSpec,
|
||||
move_top_level_schema_fields_to_options,
|
||||
)
|
||||
from homeassistant.helpers.trigger import (
|
||||
CONF_LOWER_LIMIT,
|
||||
CONF_THRESHOLD_TYPE,
|
||||
CONF_UPPER_LIMIT,
|
||||
DATA_PLUGGABLE_ACTIONS,
|
||||
EntityTriggerBase,
|
||||
PluggableAction,
|
||||
Trigger,
|
||||
TriggerActionRunner,
|
||||
TriggerConfig,
|
||||
_async_get_trigger_platform,
|
||||
async_initialize_triggers,
|
||||
async_validate_trigger_config,
|
||||
make_entity_numerical_state_attribute_changed_trigger,
|
||||
make_entity_numerical_state_attribute_crossed_threshold_trigger,
|
||||
make_entity_numerical_state_changed_trigger,
|
||||
make_entity_numerical_state_crossed_threshold_trigger,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import Integration, async_get_integration
|
||||
@@ -1242,8 +1251,8 @@ async def test_numerical_state_attribute_changed_trigger_config_validation(
|
||||
|
||||
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
|
||||
return {
|
||||
"test_trigger": make_entity_numerical_state_attribute_changed_trigger(
|
||||
{"test"}, {"test": "test_attribute"}
|
||||
"test_trigger": make_entity_numerical_state_changed_trigger(
|
||||
{"test": NumericalDomainSpec(value_source="test_attribute")}
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1270,8 +1279,8 @@ async def test_numerical_state_attribute_changed_error_handling(
|
||||
|
||||
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
|
||||
return {
|
||||
"attribute_changed": make_entity_numerical_state_attribute_changed_trigger(
|
||||
{"test"}, {"test": "test_attribute"}
|
||||
"attribute_changed": make_entity_numerical_state_changed_trigger(
|
||||
{"test": NumericalDomainSpec(value_source="test_attribute")}
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1552,8 +1561,8 @@ async def test_numerical_state_attribute_crossed_threshold_trigger_config_valida
|
||||
|
||||
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
|
||||
return {
|
||||
"test_trigger": make_entity_numerical_state_attribute_crossed_threshold_trigger(
|
||||
{"test"}, {"test": "test_attribute"}
|
||||
"test_trigger": make_entity_numerical_state_crossed_threshold_trigger(
|
||||
{"test": NumericalDomainSpec(value_source="test_attribute")}
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1571,3 +1580,126 @@ async def test_numerical_state_attribute_crossed_threshold_trigger_config_valida
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _make_trigger(
|
||||
hass: HomeAssistant, domain_specs: Mapping[str, DomainSpec]
|
||||
) -> EntityTriggerBase:
|
||||
"""Create a minimal EntityTriggerBase subclass with the given domain specs."""
|
||||
|
||||
class _SimpleTrigger(EntityTriggerBase):
|
||||
"""Minimal concrete trigger for testing entity_filter."""
|
||||
|
||||
_domain_specs = domain_specs
|
||||
|
||||
def is_valid_state(self, state):
|
||||
"""Accept any state."""
|
||||
return True
|
||||
|
||||
config = TriggerConfig(key="test.test_trigger", target={CONF_ENTITY_ID: []})
|
||||
return _SimpleTrigger(hass, config)
|
||||
|
||||
|
||||
async def test_entity_filter_by_domain_only(hass: HomeAssistant) -> None:
|
||||
"""Test entity_filter includes entities matching domain, excludes others."""
|
||||
trig = _make_trigger(hass, {"sensor": DomainSpec(), "switch": DomainSpec()})
|
||||
|
||||
entities = {
|
||||
"sensor.temp",
|
||||
"sensor.humidity",
|
||||
"switch.light",
|
||||
"light.bedroom",
|
||||
"cover.garage",
|
||||
}
|
||||
result = trig.entity_filter(entities)
|
||||
assert result == {"sensor.temp", "sensor.humidity", "switch.light"}
|
||||
|
||||
|
||||
async def test_entity_filter_by_device_class(hass: HomeAssistant) -> None:
|
||||
"""Test entity_filter filters by device_class when specified."""
|
||||
trig = _make_trigger(hass, {"sensor": DomainSpec(device_class="humidity")})
|
||||
|
||||
# Set states with device_class attributes
|
||||
hass.states.async_set("sensor.humidity_1", "50", {ATTR_DEVICE_CLASS: "humidity"})
|
||||
hass.states.async_set(
|
||||
"sensor.temperature_1", "22", {ATTR_DEVICE_CLASS: "temperature"}
|
||||
)
|
||||
hass.states.async_set("sensor.no_class", "10", {})
|
||||
|
||||
entities = {"sensor.humidity_1", "sensor.temperature_1", "sensor.no_class"}
|
||||
result = trig.entity_filter(entities)
|
||||
assert result == {"sensor.humidity_1"}
|
||||
|
||||
|
||||
async def test_entity_filter_device_class_unknown_entity(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test entity_filter excludes entities not in state machine or registry."""
|
||||
trig = _make_trigger(hass, {"sensor": DomainSpec(device_class="humidity")})
|
||||
|
||||
# Entity not in state machine and not in entity registry -> UNDEFINED
|
||||
entities = {"sensor.nonexistent"}
|
||||
result = trig.entity_filter(entities)
|
||||
assert result == set()
|
||||
|
||||
|
||||
async def test_entity_filter_multiple_domains_with_device_class(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test entity_filter with multiple domains, some with device_class filtering."""
|
||||
trig = _make_trigger(
|
||||
hass,
|
||||
{
|
||||
"climate": DomainSpec(value_source="current_humidity"),
|
||||
"sensor": DomainSpec(device_class="humidity"),
|
||||
"weather": DomainSpec(value_source="humidity"),
|
||||
},
|
||||
)
|
||||
|
||||
hass.states.async_set("sensor.humidity", "60", {ATTR_DEVICE_CLASS: "humidity"})
|
||||
hass.states.async_set(
|
||||
"sensor.temperature", "20", {ATTR_DEVICE_CLASS: "temperature"}
|
||||
)
|
||||
hass.states.async_set("climate.hvac", "heat", {})
|
||||
hass.states.async_set("weather.home", "sunny", {})
|
||||
hass.states.async_set("light.bedroom", "on", {})
|
||||
|
||||
entities = {
|
||||
"sensor.humidity",
|
||||
"sensor.temperature",
|
||||
"climate.hvac",
|
||||
"weather.home",
|
||||
"light.bedroom",
|
||||
}
|
||||
result = trig.entity_filter(entities)
|
||||
# sensor.temperature excluded (wrong device_class)
|
||||
# light.bedroom excluded (no matching domain)
|
||||
assert result == {"sensor.humidity", "climate.hvac", "weather.home"}
|
||||
|
||||
|
||||
async def test_entity_filter_no_device_class_means_match_all_in_domain(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that DomainSpec without device_class matches all entities in the domain."""
|
||||
trig = _make_trigger(hass, {"cover": DomainSpec()})
|
||||
|
||||
hass.states.async_set("cover.door", "open", {ATTR_DEVICE_CLASS: "door"})
|
||||
hass.states.async_set("cover.garage", "closed", {ATTR_DEVICE_CLASS: "garage"})
|
||||
hass.states.async_set("cover.plain", "open", {})
|
||||
|
||||
entities = {"cover.door", "cover.garage", "cover.plain"}
|
||||
result = trig.entity_filter(entities)
|
||||
assert result == entities
|
||||
|
||||
|
||||
async def test_numerical_domain_spec_converter(hass: HomeAssistant) -> None:
|
||||
"""Test NumericalDomainSpec stores converter correctly."""
|
||||
converter = lambda v: float(v) / 255.0 * 100.0 # noqa: E731
|
||||
nvs = NumericalDomainSpec(value_source="brightness", value_converter=converter)
|
||||
assert nvs.value_source == "brightness"
|
||||
assert nvs.value_converter is converter
|
||||
assert nvs.device_class is ANY_DEVICE_CLASS
|
||||
|
||||
# Plain DomainSpec has no converter
|
||||
vs = DomainSpec(value_source="brightness")
|
||||
assert not isinstance(vs, NumericalDomainSpec)
|
||||
|
||||
Reference in New Issue
Block a user