1
0
mirror of https://github.com/home-assistant/core.git synced 2026-04-17 23:53:49 +01:00

Add DomainSpec to trigger and condition helpers (#165392)

This commit is contained in:
Ariel Ebersberger
2026-03-13 19:50:19 +01:00
committed by GitHub
parent d96191723f
commit eb17367229
20 changed files with 365 additions and 211 deletions

View File

@@ -2,6 +2,7 @@
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.condition import (
Condition,
EntityStateConditionBase,
@@ -43,7 +44,7 @@ def make_entity_state_required_features_condition(
class CustomCondition(EntityStateRequiredFeaturesCondition):
"""Condition for entity state changes."""
_domain = domain
_domain_specs = {domain: DomainSpec()}
_states = {to_state}
_required_features = required_features

View File

@@ -2,6 +2,7 @@
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.entity import get_supported_features
from homeassistant.helpers.trigger import (
EntityTargetStateTriggerBase,
@@ -44,7 +45,7 @@ def make_entity_state_trigger_required_features(
class CustomTrigger(EntityStateTriggerRequiredFeatures):
"""Trigger for entity state changes."""
_domains = {domain}
_domain_specs = {domain: DomainSpec()}
_to_states = {to_state}
_required_features = required_features

View File

@@ -2,6 +2,7 @@
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
@@ -14,7 +15,7 @@ from . import DOMAIN
class ButtonPressedTrigger(EntityTriggerBase):
"""Trigger for button entity presses."""
_domains = {DOMAIN}
_domain_specs = {DOMAIN: DomainSpec()}
_schema = ENTITY_STATE_TRIGGER_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:

View File

@@ -5,13 +5,14 @@ import voluptuous as vol
from homeassistant.const import ATTR_TEMPERATURE, CONF_OPTIONS
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import DomainSpec, NumericalDomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST,
EntityTargetStateTriggerBase,
Trigger,
TriggerConfig,
make_entity_numerical_state_attribute_changed_trigger,
make_entity_numerical_state_attribute_crossed_threshold_trigger,
make_entity_numerical_state_changed_trigger,
make_entity_numerical_state_crossed_threshold_trigger,
make_entity_target_state_attribute_trigger,
make_entity_target_state_trigger,
make_entity_transition_trigger,
@@ -35,7 +36,7 @@ HVAC_MODE_CHANGED_TRIGGER_SCHEMA = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST.extend
class HVACModeChangedTrigger(EntityTargetStateTriggerBase):
"""Trigger for entity state changes."""
_domains = {DOMAIN}
_domain_specs = {DOMAIN: DomainSpec()}
_schema = HVAC_MODE_CHANGED_TRIGGER_SCHEMA
def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
@@ -52,17 +53,17 @@ TRIGGERS: dict[str, type[Trigger]] = {
"started_drying": make_entity_target_state_attribute_trigger(
DOMAIN, ATTR_HVAC_ACTION, HVACAction.DRYING
),
"target_humidity_changed": make_entity_numerical_state_attribute_changed_trigger(
{DOMAIN}, {DOMAIN: ATTR_HUMIDITY}
"target_humidity_changed": make_entity_numerical_state_changed_trigger(
{DOMAIN: NumericalDomainSpec(value_source=ATTR_HUMIDITY)}
),
"target_humidity_crossed_threshold": make_entity_numerical_state_attribute_crossed_threshold_trigger(
{DOMAIN}, {DOMAIN: ATTR_HUMIDITY}
"target_humidity_crossed_threshold": make_entity_numerical_state_crossed_threshold_trigger(
{DOMAIN: NumericalDomainSpec(value_source=ATTR_HUMIDITY)}
),
"target_temperature_changed": make_entity_numerical_state_attribute_changed_trigger(
{DOMAIN}, {DOMAIN: ATTR_TEMPERATURE}
"target_temperature_changed": make_entity_numerical_state_changed_trigger(
{DOMAIN: NumericalDomainSpec(value_source=ATTR_TEMPERATURE)}
),
"target_temperature_crossed_threshold": make_entity_numerical_state_attribute_crossed_threshold_trigger(
{DOMAIN}, {DOMAIN: ATTR_TEMPERATURE}
"target_temperature_crossed_threshold": make_entity_numerical_state_crossed_threshold_trigger(
{DOMAIN: NumericalDomainSpec(value_source=ATTR_TEMPERATURE)}
),
"turned_off": make_entity_target_state_trigger(DOMAIN, HVACMode.OFF),
"turned_on": make_entity_transition_trigger(

View File

@@ -1,81 +1,82 @@
"""Provides triggers for covers."""
from dataclasses import dataclass
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State, split_entity_id
from homeassistant.helpers.trigger import (
EntityTriggerBase,
Trigger,
get_device_class_or_undefined,
)
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import EntityTriggerBase, Trigger
from .const import ATTR_IS_CLOSED, DOMAIN, CoverDeviceClass
class CoverTriggerBase(EntityTriggerBase):
@dataclass(frozen=True, slots=True)
class CoverDomainSpec(DomainSpec):
"""DomainSpec with a target value for comparison."""
target_value: str | bool | None = None
class CoverTriggerBase(EntityTriggerBase[CoverDomainSpec]):
"""Base trigger for cover state changes."""
_binary_sensor_target_state: str
_cover_is_closed_target_value: bool
_device_classes: dict[str, str]
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities by cover device class."""
entities = super().entity_filter(entities)
return {
entity_id
for entity_id in entities
if get_device_class_or_undefined(self._hass, entity_id)
== self._device_classes[split_entity_id(entity_id)[0]]
}
def _get_value(self, state: State) -> str | bool | None:
"""Extract the relevant value from state based on domain spec."""
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
if domain_spec.value_source is not None:
return state.attributes.get(domain_spec.value_source)
return state.state
def is_valid_state(self, state: State) -> bool:
"""Check if the state matches the target cover state."""
if split_entity_id(state.entity_id)[0] == DOMAIN:
return (
state.attributes.get(ATTR_IS_CLOSED)
== self._cover_is_closed_target_value
)
return state.state == self._binary_sensor_target_state
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
return self._get_value(state) == domain_spec.target_value
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the transition is valid for a cover state change."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
if split_entity_id(from_state.entity_id)[0] == DOMAIN:
if (from_is_closed := from_state.attributes.get(ATTR_IS_CLOSED)) is None:
return False
return from_is_closed != to_state.attributes.get(ATTR_IS_CLOSED) # type: ignore[no-any-return]
return from_state.state != to_state.state
if (from_value := self._get_value(from_state)) is None:
return False
return from_value != self._get_value(to_state)
def make_cover_opened_trigger(
*, device_classes: dict[str, str], domains: set[str] | None = None
*, device_classes: dict[str, str]
) -> type[CoverTriggerBase]:
"""Create a trigger cover_opened."""
class CoverOpenedTrigger(CoverTriggerBase):
"""Trigger for cover opened state changes."""
_binary_sensor_target_state = STATE_ON
_cover_is_closed_target_value = False
_domains = domains or {DOMAIN}
_device_classes = device_classes
_domain_specs = {
domain: CoverDomainSpec(
device_class=dc,
value_source=ATTR_IS_CLOSED if domain == DOMAIN else None,
target_value=False if domain == DOMAIN else STATE_ON,
)
for domain, dc in device_classes.items()
}
return CoverOpenedTrigger
def make_cover_closed_trigger(
*, device_classes: dict[str, str], domains: set[str] | None = None
*, device_classes: dict[str, str]
) -> type[CoverTriggerBase]:
"""Create a trigger cover_closed."""
class CoverClosedTrigger(CoverTriggerBase):
"""Trigger for cover closed state changes."""
_binary_sensor_target_state = STATE_OFF
_cover_is_closed_target_value = True
_domains = domains or {DOMAIN}
_device_classes = device_classes
_domain_specs = {
domain: CoverDomainSpec(
device_class=dc,
value_source=ATTR_IS_CLOSED if domain == DOMAIN else None,
target_value=True if domain == DOMAIN else STATE_OFF,
)
for domain, dc in device_classes.items()
}
return CoverClosedTrigger

View File

@@ -20,14 +20,8 @@ DEVICE_CLASSES_DOOR: dict[str, str] = {
TRIGGERS: dict[str, type[Trigger]] = {
"opened": make_cover_opened_trigger(
device_classes=DEVICE_CLASSES_DOOR,
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
),
"closed": make_cover_closed_trigger(
device_classes=DEVICE_CLASSES_DOOR,
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
),
"opened": make_cover_opened_trigger(device_classes=DEVICE_CLASSES_DOOR),
"closed": make_cover_closed_trigger(device_classes=DEVICE_CLASSES_DOOR),
}

View File

@@ -20,14 +20,8 @@ DEVICE_CLASSES_GARAGE_DOOR: dict[str, str] = {
TRIGGERS: dict[str, type[Trigger]] = {
"opened": make_cover_opened_trigger(
device_classes=DEVICE_CLASSES_GARAGE_DOOR,
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
),
"closed": make_cover_closed_trigger(
device_classes=DEVICE_CLASSES_GARAGE_DOOR,
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
),
"opened": make_cover_opened_trigger(device_classes=DEVICE_CLASSES_GARAGE_DOOR),
"closed": make_cover_closed_trigger(device_classes=DEVICE_CLASSES_GARAGE_DOOR),
}

View File

@@ -15,50 +15,43 @@ from homeassistant.components.weather import (
ATTR_WEATHER_HUMIDITY,
DOMAIN as WEATHER_DOMAIN,
)
from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.core import HomeAssistant
from homeassistant.helpers.automation import NumericalDomainSpec
from homeassistant.helpers.trigger import (
EntityNumericalStateAttributeChangedTriggerBase,
EntityNumericalStateAttributeCrossedThresholdTriggerBase,
EntityTriggerBase,
Trigger,
get_device_class_or_undefined,
)
class _HumidityTriggerMixin(EntityTriggerBase):
"""Mixin for humidity triggers providing entity filtering and value extraction."""
_attributes = {
CLIMATE_DOMAIN: CLIMATE_ATTR_CURRENT_HUMIDITY,
HUMIDIFIER_DOMAIN: HUMIDIFIER_ATTR_CURRENT_HUMIDITY,
SENSOR_DOMAIN: None, # Use state.state
WEATHER_DOMAIN: ATTR_WEATHER_HUMIDITY,
}
_domains = {SENSOR_DOMAIN, CLIMATE_DOMAIN, HUMIDIFIER_DOMAIN, WEATHER_DOMAIN}
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities: all climate/humidifier/weather, sensor only with device_class humidity."""
entities = super().entity_filter(entities)
return {
entity_id
for entity_id in entities
if split_entity_id(entity_id)[0] != SENSOR_DOMAIN
or get_device_class_or_undefined(self._hass, entity_id)
== SensorDeviceClass.HUMIDITY
}
HUMIDITY_DOMAIN_SPECS: dict[str, NumericalDomainSpec] = {
CLIMATE_DOMAIN: NumericalDomainSpec(
value_source=CLIMATE_ATTR_CURRENT_HUMIDITY,
),
HUMIDIFIER_DOMAIN: NumericalDomainSpec(
value_source=HUMIDIFIER_ATTR_CURRENT_HUMIDITY,
),
SENSOR_DOMAIN: NumericalDomainSpec(
device_class=SensorDeviceClass.HUMIDITY,
),
WEATHER_DOMAIN: NumericalDomainSpec(
value_source=ATTR_WEATHER_HUMIDITY,
),
}
class HumidityChangedTrigger(
_HumidityTriggerMixin, EntityNumericalStateAttributeChangedTriggerBase
):
class HumidityChangedTrigger(EntityNumericalStateAttributeChangedTriggerBase):
"""Trigger for humidity value changes across multiple domains."""
_domain_specs = HUMIDITY_DOMAIN_SPECS
class HumidityCrossedThresholdTrigger(
_HumidityTriggerMixin, EntityNumericalStateAttributeCrossedThresholdTriggerBase
EntityNumericalStateAttributeCrossedThresholdTriggerBase
):
"""Trigger for humidity value crossing a threshold across multiple domains."""
_domain_specs = HUMIDITY_DOMAIN_SPECS
TRIGGERS: dict[str, type[Trigger]] = {
"changed": HumidityChangedTrigger,

View File

@@ -4,6 +4,7 @@ from typing import Any
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers.automation import NumericalDomainSpec
from homeassistant.helpers.trigger import (
EntityNumericalStateAttributeChangedTriggerBase,
EntityNumericalStateAttributeCrossedThresholdTriggerBase,
@@ -20,13 +21,18 @@ def _convert_uint8_to_percentage(value: Any) -> float:
return (float(value) / 255.0) * 100.0
BRIGHTNESS_DOMAIN_SPECS = {
DOMAIN: NumericalDomainSpec(
value_source=ATTR_BRIGHTNESS,
value_converter=_convert_uint8_to_percentage,
),
}
class BrightnessChangedTrigger(EntityNumericalStateAttributeChangedTriggerBase):
"""Trigger for brightness changed."""
_domains = {DOMAIN}
_attributes = {DOMAIN: ATTR_BRIGHTNESS}
_converter = staticmethod(_convert_uint8_to_percentage)
_domain_specs = BRIGHTNESS_DOMAIN_SPECS
class BrightnessCrossedThresholdTrigger(
@@ -34,9 +40,7 @@ class BrightnessCrossedThresholdTrigger(
):
"""Trigger for brightness crossed threshold."""
_domains = {DOMAIN}
_attributes = {DOMAIN: ATTR_BRIGHTNESS}
_converter = staticmethod(_convert_uint8_to_percentage)
_domain_specs = BRIGHTNESS_DOMAIN_SPECS
TRIGGERS: dict[str, type[Trigger]] = {

View File

@@ -6,28 +6,20 @@ from homeassistant.components.binary_sensor import (
)
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
EntityTargetStateTriggerBase,
EntityTriggerBase,
Trigger,
get_device_class_or_undefined,
)
class _MotionBinaryTriggerBase(EntityTriggerBase):
"""Base trigger for motion binary sensor state changes."""
_domains = {BINARY_SENSOR_DOMAIN}
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities by motion device class."""
entities = super().entity_filter(entities)
return {
entity_id
for entity_id in entities
if get_device_class_or_undefined(self._hass, entity_id)
== BinarySensorDeviceClass.MOTION
}
_domain_specs = {
BINARY_SENSOR_DOMAIN: DomainSpec(device_class=BinarySensorDeviceClass.MOTION)
}
class MotionDetectedTrigger(_MotionBinaryTriggerBase, EntityTargetStateTriggerBase):

View File

@@ -6,28 +6,20 @@ from homeassistant.components.binary_sensor import (
)
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
EntityTargetStateTriggerBase,
EntityTriggerBase,
Trigger,
get_device_class_or_undefined,
)
class _OccupancyBinaryTriggerBase(EntityTriggerBase):
"""Base trigger for occupancy binary sensor state changes."""
_domains = {BINARY_SENSOR_DOMAIN}
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities by occupancy device class."""
entities = super().entity_filter(entities)
return {
entity_id
for entity_id in entities
if get_device_class_or_undefined(self._hass, entity_id)
== BinarySensorDeviceClass.OCCUPANCY
}
_domain_specs = {
BINARY_SENSOR_DOMAIN: DomainSpec(device_class=BinarySensorDeviceClass.OCCUPANCY)
}
class OccupancyDetectedTrigger(

View File

@@ -2,6 +2,7 @@
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
@@ -14,7 +15,7 @@ from . import DOMAIN
class SceneActivatedTrigger(EntityTriggerBase):
"""Trigger for scene entity activations."""
_domains = {DOMAIN}
_domain_specs = {DOMAIN: DomainSpec()}
_schema = ENTITY_STATE_TRIGGER_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:

View File

@@ -2,6 +2,7 @@
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
EntityTransitionTriggerBase,
Trigger,
@@ -14,7 +15,7 @@ from .const import ATTR_NEXT_EVENT, DOMAIN
class ScheduleBackToBackTrigger(EntityTransitionTriggerBase):
"""Trigger for back-to-back schedule blocks."""
_domains = {DOMAIN}
_domain_specs = {DOMAIN: DomainSpec()}
_from_states = {STATE_OFF, STATE_ON}
_to_states = {STATE_ON}

View File

@@ -2,6 +2,7 @@
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
EntityTriggerBase,
@@ -14,7 +15,7 @@ from .const import DOMAIN
class TextChangedTrigger(EntityTriggerBase):
"""Trigger for text entity when its content changes."""
_domains = {DOMAIN}
_domain_specs = {DOMAIN: DomainSpec()}
_schema = ENTITY_STATE_TRIGGER_SCHEMA
def is_valid_state(self, state: State) -> bool:

View File

@@ -20,14 +20,8 @@ DEVICE_CLASSES_WINDOW: dict[str, str] = {
TRIGGERS: dict[str, type[Trigger]] = {
"opened": make_cover_opened_trigger(
device_classes=DEVICE_CLASSES_WINDOW,
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
),
"closed": make_cover_closed_trigger(
device_classes=DEVICE_CLASSES_WINDOW,
domains={BINARY_SENSOR_DOMAIN, COVER_DOMAIN},
),
"opened": make_cover_opened_trigger(device_classes=DEVICE_CLASSES_WINDOW),
"closed": make_cover_closed_trigger(device_classes=DEVICE_CLASSES_WINDOW),
}

View File

@@ -1,14 +1,68 @@
"""Helpers for automation."""
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from enum import Enum
from typing import Any
import voluptuous as vol
from homeassistant.const import CONF_OPTIONS
from homeassistant.core import HomeAssistant, split_entity_id
from .entity import get_device_class_or_undefined
from .typing import ConfigType
class AnyDeviceClassType(Enum):
"""Singleton type for matching any device class."""
_singleton = 0
ANY_DEVICE_CLASS = AnyDeviceClassType._singleton # noqa: SLF001
@dataclass(frozen=True, slots=True)
class DomainSpec:
"""Describes how to match and extract a value from an entity.
Used by triggers and conditions.
"""
device_class: str | None | AnyDeviceClassType = ANY_DEVICE_CLASS
value_source: str | None = None
"""Attribute name to extract the value from, or None for state.state."""
@dataclass(frozen=True, slots=True)
class NumericalDomainSpec(DomainSpec):
"""DomainSpec with an optional value converter for numerical triggers."""
value_converter: Callable[[Any], float] | None = None
"""Optional converter for numerical values (e.g. uint8 → percentage)."""
def filter_by_domain_specs(
hass: HomeAssistant,
domain_specs: Mapping[str, DomainSpec],
entities: set[str],
) -> set[str]:
"""Filter entities matching any of the domain specs."""
result: set[str] = set()
for entity_id in entities:
if not (domain_spec := domain_specs.get(split_entity_id(entity_id)[0])):
continue
if (
domain_spec.device_class is not ANY_DEVICE_CLASS
and get_device_class_or_undefined(hass, entity_id)
!= domain_spec.device_class
):
continue
result.add(entity_id)
return result
def get_absolute_description_key(domain: str, key: str) -> str:
"""Return the absolute description key."""
if not key.startswith("_"):

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import abc
from collections import deque
from collections.abc import Callable, Container, Coroutine, Generator, Iterable
from collections.abc import Callable, Container, Coroutine, Generator, Iterable, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, time as dt_time, timedelta
@@ -54,7 +54,7 @@ from homeassistant.const import (
STATE_UNKNOWN,
WEEKDAYS,
)
from homeassistant.core import HomeAssistant, State, callback, split_entity_id
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.exceptions import (
ConditionError,
ConditionErrorContainer,
@@ -76,6 +76,8 @@ from homeassistant.util.yaml import load_yaml_dict
from . import config_validation as cv, entity_registry as er, selector
from .automation import (
DomainSpec,
filter_by_domain_specs,
get_absolute_description_key,
get_relative_description_key,
move_options_fields_to_top_level,
@@ -332,10 +334,10 @@ ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL = vol.Schema(
)
class EntityConditionBase(Condition):
class EntityConditionBase[DomainSpecT: DomainSpec = DomainSpec](Condition):
"""Base class for entity conditions."""
_domain: str
_domain_specs: Mapping[str, DomainSpecT]
_schema: vol.Schema = ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL
@override
@@ -356,12 +358,8 @@ class EntityConditionBase(Condition):
self._behavior = config.options[ATTR_BEHAVIOR]
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities of this domain."""
return {
entity_id
for entity_id in entities
if split_entity_id(entity_id)[0] == self._domain
}
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
@abc.abstractmethod
def is_valid_state(self, entity_state: State) -> bool:
@@ -428,7 +426,7 @@ def make_entity_state_condition(
class CustomCondition(EntityStateConditionBase):
"""Condition for entity state."""
_domain = domain
_domain_specs = {domain: DomainSpec()}
_states = states_set
return CustomCondition
@@ -458,7 +456,7 @@ def make_entity_state_attribute_condition(
class CustomCondition(EntityStateAttributeConditionBase):
"""Condition for entity attribute."""
_domain = domain
_domain_specs = {domain: DomainSpec()}
_attribute = attribute
_attribute_states = attribute_states_set

View File

@@ -169,6 +169,16 @@ def get_device_class(hass: HomeAssistant, entity_id: str) -> str | None:
return entry.device_class or entry.original_device_class
def get_device_class_or_undefined(
hass: HomeAssistant, entity_id: str
) -> str | None | UndefinedType:
"""Get the device class of an entity or UNDEFINED if not found."""
try:
return get_device_class(hass, entity_id)
except HomeAssistantError:
return UNDEFINED
def get_supported_features(hass: HomeAssistant, entity_id: str) -> int:
"""Get supported features for an entity.

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import abc
import asyncio
from collections import defaultdict
from collections.abc import Callable, Coroutine, Iterable
from collections.abc import Callable, Coroutine, Iterable, Mapping
from dataclasses import dataclass, field
from enum import StrEnum
import functools
@@ -69,11 +69,13 @@ from homeassistant.util.yaml import load_yaml_dict
from . import config_validation as cv, selector
from .automation import (
DomainSpec,
NumericalDomainSpec,
filter_by_domain_specs,
get_absolute_description_key,
get_relative_description_key,
move_options_fields_to_top_level,
)
from .entity import get_device_class
from .integration_platform import async_process_integration_platforms
from .selector import TargetSelector
from .target import (
@@ -81,7 +83,7 @@ from .target import (
async_track_target_selector_state_change_event,
)
from .template import Template
from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType
from .typing import ConfigType, TemplateVarsType
_LOGGER = logging.getLogger(__name__)
@@ -334,20 +336,10 @@ ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST = ENTITY_STATE_TRIGGER_SCHEMA.extend(
)
def get_device_class_or_undefined(
hass: HomeAssistant, entity_id: str
) -> str | None | UndefinedType:
"""Get the device class of an entity or UNDEFINED if not found."""
try:
return get_device_class(hass, entity_id)
except HomeAssistantError:
return UNDEFINED
class EntityTriggerBase(Trigger):
class EntityTriggerBase[DomainSpecT: DomainSpec = DomainSpec](Trigger):
"""Trigger for entity state changes."""
_domains: set[str]
_domain_specs: Mapping[str, DomainSpecT]
_schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST
@override
@@ -366,6 +358,10 @@ class EntityTriggerBase(Trigger):
self._options = config.options or {}
self._target = config.target
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
@@ -396,14 +392,6 @@ class EntityTriggerBase(Trigger):
== 1
)
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities of these domains."""
return {
entity_id
for entity_id in entities
if split_entity_id(entity_id)[0] in self._domains
}
@override
async def async_attach_runner(
self, run_action: TriggerActionRunner
@@ -611,19 +599,22 @@ def _get_numerical_value(
return entity_or_float
class EntityNumericalStateBase(EntityTriggerBase):
class EntityNumericalStateBase(EntityTriggerBase[NumericalDomainSpec]):
"""Base class for numerical state and state attribute triggers."""
_attributes: dict[str, str | None]
_converter: Callable[[Any], float] = float
def _get_tracked_value(self, state: State) -> Any:
"""Get the tracked numerical value from a state."""
domain = split_entity_id(state.entity_id)[0]
source = self._attributes[domain]
if source is None:
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
if domain_spec.value_source is None:
return state.state
return state.attributes.get(source)
return state.attributes.get(domain_spec.value_source)
def _get_converter(self, state: State) -> Callable[[Any], float]:
"""Get the value converter for an entity."""
domain_spec = self._domain_specs[split_entity_id(state.entity_id)[0]]
if domain_spec.value_converter is not None:
return domain_spec.value_converter
return float
class EntityNumericalStateAttributeChangedTriggerBase(EntityNumericalStateBase):
@@ -654,7 +645,7 @@ class EntityNumericalStateAttributeChangedTriggerBase(EntityNumericalStateBase):
return False
try:
current_value = self._converter(_attribute_value)
current_value = self._get_converter(state)(_attribute_value)
except TypeError, ValueError:
# Value is not a valid number, don't trigger
return False
@@ -780,7 +771,7 @@ class EntityNumericalStateAttributeCrossedThresholdTriggerBase(
return False
try:
current_value = self._converter(_attribute_value)
current_value = self._get_converter(state)(_attribute_value)
except TypeError, ValueError:
# Value is not a valid number, don't trigger
return False
@@ -812,7 +803,7 @@ def make_entity_target_state_trigger(
class CustomTrigger(EntityTargetStateTriggerBase):
"""Trigger for entity state changes."""
_domains = {domain}
_domain_specs = {domain: DomainSpec()}
_to_states = to_states_set
return CustomTrigger
@@ -826,7 +817,7 @@ def make_entity_transition_trigger(
class CustomTrigger(EntityTransitionTriggerBase):
"""Trigger for conditional entity state changes."""
_domains = {domain}
_domain_specs = {domain: DomainSpec()}
_from_states = from_states
_to_states = to_states
@@ -841,36 +832,34 @@ def make_entity_origin_state_trigger(
class CustomTrigger(EntityOriginStateTriggerBase):
"""Trigger for entity "from state" changes."""
_domains = {domain}
_domain_specs = {domain: DomainSpec()}
_from_state = from_state
return CustomTrigger
def make_entity_numerical_state_attribute_changed_trigger(
domains: set[str], attributes: dict[str, str | None]
def make_entity_numerical_state_changed_trigger(
domain_specs: Mapping[str, NumericalDomainSpec],
) -> type[EntityNumericalStateAttributeChangedTriggerBase]:
"""Create a trigger for numerical state attribute change."""
"""Create a trigger for numerical state value change."""
class CustomTrigger(EntityNumericalStateAttributeChangedTriggerBase):
"""Trigger for numerical state attribute changes."""
"""Trigger for numerical state value changes."""
_domains = domains
_attributes = attributes
_domain_specs = domain_specs
return CustomTrigger
def make_entity_numerical_state_attribute_crossed_threshold_trigger(
domains: set[str], attributes: dict[str, str | None]
def make_entity_numerical_state_crossed_threshold_trigger(
domain_specs: Mapping[str, NumericalDomainSpec],
) -> type[EntityNumericalStateAttributeCrossedThresholdTriggerBase]:
"""Create a trigger for numerical state attribute change."""
"""Create a trigger for numerical state value crossing a threshold."""
class CustomTrigger(EntityNumericalStateAttributeCrossedThresholdTriggerBase):
"""Trigger for numerical state attribute changes."""
"""Trigger for numerical state value crossing a threshold."""
_domains = domains
_attributes = attributes
_domain_specs = domain_specs
return CustomTrigger
@@ -883,7 +872,7 @@ def make_entity_target_state_attribute_trigger(
class CustomTrigger(EntityTargetStateAttributeTriggerBase):
"""Trigger for entity state changes."""
_domains = {domain}
_domain_specs = {domain: DomainSpec()}
_attribute = attribute
_attribute_to_state = to_state

View File

@@ -1,5 +1,6 @@
"""The tests for the trigger helper."""
from collections.abc import Mapping
from contextlib import AbstractContextManager, nullcontext as does_not_raise
import io
from typing import Any
@@ -15,6 +16,7 @@ from homeassistant.components.system_health import DOMAIN as SYSTEM_HEALTH_DOMAI
from homeassistant.components.tag import DOMAIN as TAG_DOMAIN
from homeassistant.components.text import DOMAIN as TEXT_DOMAIN
from homeassistant.const import (
ATTR_DEVICE_CLASS,
CONF_ABOVE,
CONF_BELOW,
CONF_ENTITY_ID,
@@ -33,20 +35,27 @@ from homeassistant.core import (
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, trigger
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.automation import (
ANY_DEVICE_CLASS,
DomainSpec,
NumericalDomainSpec,
move_top_level_schema_fields_to_options,
)
from homeassistant.helpers.trigger import (
CONF_LOWER_LIMIT,
CONF_THRESHOLD_TYPE,
CONF_UPPER_LIMIT,
DATA_PLUGGABLE_ACTIONS,
EntityTriggerBase,
PluggableAction,
Trigger,
TriggerActionRunner,
TriggerConfig,
_async_get_trigger_platform,
async_initialize_triggers,
async_validate_trigger_config,
make_entity_numerical_state_attribute_changed_trigger,
make_entity_numerical_state_attribute_crossed_threshold_trigger,
make_entity_numerical_state_changed_trigger,
make_entity_numerical_state_crossed_threshold_trigger,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import Integration, async_get_integration
@@ -1242,8 +1251,8 @@ async def test_numerical_state_attribute_changed_trigger_config_validation(
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
return {
"test_trigger": make_entity_numerical_state_attribute_changed_trigger(
{"test"}, {"test": "test_attribute"}
"test_trigger": make_entity_numerical_state_changed_trigger(
{"test": NumericalDomainSpec(value_source="test_attribute")}
),
}
@@ -1270,8 +1279,8 @@ async def test_numerical_state_attribute_changed_error_handling(
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
return {
"attribute_changed": make_entity_numerical_state_attribute_changed_trigger(
{"test"}, {"test": "test_attribute"}
"attribute_changed": make_entity_numerical_state_changed_trigger(
{"test": NumericalDomainSpec(value_source="test_attribute")}
),
}
@@ -1552,8 +1561,8 @@ async def test_numerical_state_attribute_crossed_threshold_trigger_config_valida
async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]:
return {
"test_trigger": make_entity_numerical_state_attribute_crossed_threshold_trigger(
{"test"}, {"test": "test_attribute"}
"test_trigger": make_entity_numerical_state_crossed_threshold_trigger(
{"test": NumericalDomainSpec(value_source="test_attribute")}
),
}
@@ -1571,3 +1580,126 @@ async def test_numerical_state_attribute_crossed_threshold_trigger_config_valida
}
],
)
def _make_trigger(
hass: HomeAssistant, domain_specs: Mapping[str, DomainSpec]
) -> EntityTriggerBase:
"""Create a minimal EntityTriggerBase subclass with the given domain specs."""
class _SimpleTrigger(EntityTriggerBase):
"""Minimal concrete trigger for testing entity_filter."""
_domain_specs = domain_specs
def is_valid_state(self, state):
"""Accept any state."""
return True
config = TriggerConfig(key="test.test_trigger", target={CONF_ENTITY_ID: []})
return _SimpleTrigger(hass, config)
async def test_entity_filter_by_domain_only(hass: HomeAssistant) -> None:
"""Test entity_filter includes entities matching domain, excludes others."""
trig = _make_trigger(hass, {"sensor": DomainSpec(), "switch": DomainSpec()})
entities = {
"sensor.temp",
"sensor.humidity",
"switch.light",
"light.bedroom",
"cover.garage",
}
result = trig.entity_filter(entities)
assert result == {"sensor.temp", "sensor.humidity", "switch.light"}
async def test_entity_filter_by_device_class(hass: HomeAssistant) -> None:
"""Test entity_filter filters by device_class when specified."""
trig = _make_trigger(hass, {"sensor": DomainSpec(device_class="humidity")})
# Set states with device_class attributes
hass.states.async_set("sensor.humidity_1", "50", {ATTR_DEVICE_CLASS: "humidity"})
hass.states.async_set(
"sensor.temperature_1", "22", {ATTR_DEVICE_CLASS: "temperature"}
)
hass.states.async_set("sensor.no_class", "10", {})
entities = {"sensor.humidity_1", "sensor.temperature_1", "sensor.no_class"}
result = trig.entity_filter(entities)
assert result == {"sensor.humidity_1"}
async def test_entity_filter_device_class_unknown_entity(
hass: HomeAssistant,
) -> None:
"""Test entity_filter excludes entities not in state machine or registry."""
trig = _make_trigger(hass, {"sensor": DomainSpec(device_class="humidity")})
# Entity not in state machine and not in entity registry -> UNDEFINED
entities = {"sensor.nonexistent"}
result = trig.entity_filter(entities)
assert result == set()
async def test_entity_filter_multiple_domains_with_device_class(
hass: HomeAssistant,
) -> None:
"""Test entity_filter with multiple domains, some with device_class filtering."""
trig = _make_trigger(
hass,
{
"climate": DomainSpec(value_source="current_humidity"),
"sensor": DomainSpec(device_class="humidity"),
"weather": DomainSpec(value_source="humidity"),
},
)
hass.states.async_set("sensor.humidity", "60", {ATTR_DEVICE_CLASS: "humidity"})
hass.states.async_set(
"sensor.temperature", "20", {ATTR_DEVICE_CLASS: "temperature"}
)
hass.states.async_set("climate.hvac", "heat", {})
hass.states.async_set("weather.home", "sunny", {})
hass.states.async_set("light.bedroom", "on", {})
entities = {
"sensor.humidity",
"sensor.temperature",
"climate.hvac",
"weather.home",
"light.bedroom",
}
result = trig.entity_filter(entities)
# sensor.temperature excluded (wrong device_class)
# light.bedroom excluded (no matching domain)
assert result == {"sensor.humidity", "climate.hvac", "weather.home"}
async def test_entity_filter_no_device_class_means_match_all_in_domain(
hass: HomeAssistant,
) -> None:
"""Test that DomainSpec without device_class matches all entities in the domain."""
trig = _make_trigger(hass, {"cover": DomainSpec()})
hass.states.async_set("cover.door", "open", {ATTR_DEVICE_CLASS: "door"})
hass.states.async_set("cover.garage", "closed", {ATTR_DEVICE_CLASS: "garage"})
hass.states.async_set("cover.plain", "open", {})
entities = {"cover.door", "cover.garage", "cover.plain"}
result = trig.entity_filter(entities)
assert result == entities
async def test_numerical_domain_spec_converter(hass: HomeAssistant) -> None:
"""Test NumericalDomainSpec stores converter correctly."""
converter = lambda v: float(v) / 255.0 * 100.0 # noqa: E731
nvs = NumericalDomainSpec(value_source="brightness", value_converter=converter)
assert nvs.value_source == "brightness"
assert nvs.value_converter is converter
assert nvs.device_class is ANY_DEVICE_CLASS
# Plain DomainSpec has no converter
vs = DomainSpec(value_source="brightness")
assert not isinstance(vs, NumericalDomainSpec)