"""Triggers.""" import abc import asyncio from collections import defaultdict from collections.abc import Callable, Coroutine, Iterable, Mapping from dataclasses import dataclass, field from datetime import timedelta import functools import inspect import logging from typing import ( TYPE_CHECKING, Any, ClassVar, Final, Literal, Protocol, TypedDict, cast, override, ) import voluptuous as vol from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ALIAS, CONF_DEVICE_ID, CONF_ENABLED, CONF_ENTITY_ID, CONF_EVENT_DATA, CONF_FOR, CONF_ID, CONF_OPTIONS, CONF_PLATFORM, CONF_SELECTOR, CONF_TARGET, CONF_VARIABLES, CONF_ZONE, STATE_UNAVAILABLE, STATE_UNKNOWN, ) from homeassistant.core import ( CALLBACK_TYPE, Context, HassJob, HassJobType, HomeAssistant, State, callback, get_hassjob_callable_job_type, is_callback, valid_entity_id, ) from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.loader import ( Integration, IntegrationNotFound, async_get_integration, async_get_integrations, ) from homeassistant.util.async_ import create_eager_task from homeassistant.util.hass_dict import HassKey from homeassistant.util.unit_conversion import BaseUnitConverter from homeassistant.util.yaml import load_yaml_dict from . 
import config_validation as cv, selector from .automation import ( DomainSpec, ThresholdConfig, filter_by_domain_specs, get_absolute_description_key, get_relative_description_key, move_options_fields_to_top_level, ) from .event import async_track_same_state from .integration_platform import async_process_integration_platforms from .selector import ( NumericThresholdMode, NumericThresholdSelector, NumericThresholdSelectorConfig, NumericThresholdType, TargetSelector, ) from .target import ( TargetStateChangedData, async_track_target_selector_state_change_event, ) from .template import Template from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType _LOGGER = logging.getLogger(__name__) _PLATFORM_ALIASES = { "device": "device_automation", "event": "homeassistant", "numeric_state": "homeassistant", "state": "homeassistant", "time_pattern": "homeassistant", "time": "homeassistant", } DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey( "pluggable_actions" ) TRIGGER_DESCRIPTION_CACHE: HassKey[dict[str, dict[str, Any] | None]] = HassKey( "trigger_description_cache" ) TRIGGER_DISABLED_TRIGGERS: HassKey[set[str]] = HassKey("trigger_disabled_triggers") TRIGGER_PLATFORM_SUBSCRIPTIONS: HassKey[ list[Callable[[set[str]], Coroutine[Any, Any, None]]] ] = HassKey("trigger_platform_subscriptions") TRIGGERS: HassKey[dict[str, str]] = HassKey("triggers") # Basic schemas to sanity check the trigger descriptions, # full validation is done by hassfest.triggers _FIELD_DESCRIPTION_SCHEMA = vol.Schema( { vol.Optional(CONF_SELECTOR): selector.validate_selector, }, extra=vol.ALLOW_EXTRA, ) _TRIGGER_DESCRIPTION_SCHEMA = vol.Schema( { vol.Optional("target"): TargetSelector.CONFIG_SCHEMA, vol.Optional("fields"): vol.Schema({str: _FIELD_DESCRIPTION_SCHEMA}), }, extra=vol.ALLOW_EXTRA, ) def starts_with_dot(key: str) -> str: """Check if key starts with dot.""" if not key.startswith("."): raise vol.Invalid("Key does not start with .") return key 
# Top-level schema of a trigger description mapping: keys starting with "."
# are dropped (vol.Remove), every other key must be an underscore slug whose
# value is either None or a trigger description.
_TRIGGERS_DESCRIPTION_SCHEMA = vol.Schema(
    {
        vol.Remove(vol.All(str, starts_with_dot)): object,
        cv.underscore_slug: vol.Any(None, _TRIGGER_DESCRIPTION_SCHEMA),
    }
)


async def async_setup(hass: HomeAssistant) -> None:
    """Set up the trigger helper.

    Initialises the trigger-related hass.data stores, subscribes to the
    automation "new triggers/conditions" preview feature so the caches are
    reset when it changes, and processes the "trigger" platform of the
    integrations via _register_trigger_platform.
    """
    from homeassistant.components import automation, labs  # noqa: PLC0415

    hass.data[TRIGGER_DESCRIPTION_CACHE] = {}
    hass.data[TRIGGER_DISABLED_TRIGGERS] = set()
    hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS] = []
    hass.data[TRIGGERS] = {}

    async def new_triggers_conditions_listener(
        _event_data: labs.EventLabsUpdatedData,
    ) -> None:
        """Handle new_triggers_conditions flag change."""
        # Invalidate the cache
        hass.data[TRIGGER_DESCRIPTION_CACHE] = {}
        hass.data[TRIGGER_DISABLED_TRIGGERS] = set()

    labs.async_subscribe_preview_feature(
        hass,
        automation.DOMAIN,
        automation.NEW_TRIGGERS_CONDITIONS_FEATURE_FLAG,
        new_triggers_conditions_listener,
    )

    await async_process_integration_platforms(
        hass, "trigger", _register_trigger_platform, wait_for_platforms=True
    )


@callback
def async_subscribe_platform_events(
    hass: HomeAssistant,
    on_event: Callable[[set[str]], Coroutine[Any, Any, None]],
) -> Callable[[], None]:
    """Subscribe to trigger platform events.

    `on_event` is appended to the listener list stored in hass.data and is
    awaited by _register_trigger_platform with the set of newly registered
    trigger keys. The returned callable removes the listener again.
    """
    trigger_platform_event_subscriptions = hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]

    def remove_subscription() -> None:
        trigger_platform_event_subscriptions.remove(on_event)

    trigger_platform_event_subscriptions.append(on_event)
    return remove_subscription


async def _register_trigger_platform(
    hass: HomeAssistant, integration_domain: str, platform: TriggerProtocol
) -> None:
    """Register a trigger platform and notify listeners.

    If the trigger platform does not provide any triggers, or it is disabled,
    listeners will not be notified.
    """
    from homeassistant.components import automation  # noqa: PLC0415

    new_triggers: set[str] = set()
    triggers = hass.data[TRIGGERS]

    if hasattr(platform, "async_get_triggers"):
        # Platform exposes its triggers through async_get_triggers: register
        # every key that is not known yet under its absolute description key.
        all_triggers = await platform.async_get_triggers(hass)
        for trigger_key in all_triggers:
            trigger_key = get_absolute_description_key(integration_domain, trigger_key)
            if trigger_key not in triggers:
                triggers[trigger_key] = integration_domain
                new_triggers.add(trigger_key)
        if not new_triggers:
            # Only log when the platform returned nothing at all; if every
            # trigger was already registered there is simply nothing to do.
            # NOTE(review): the nesting of this early return was reconstructed
            # from whitespace-collapsed source; confirm it sits at this level.
            if not all_triggers:
                _LOGGER.debug(
                    "Integration %s returned no triggers in async_get_triggers",
                    integration_domain,
                )
            return
    elif hasattr(platform, "async_validate_trigger_config") or hasattr(
        platform, "TRIGGER_SCHEMA"
    ):
        # Platform without async_get_triggers: it is registered as a single
        # trigger keyed by the integration domain itself.
        if integration_domain in triggers:
            return
        triggers[integration_domain] = integration_domain
        new_triggers.add(integration_domain)
    else:
        _LOGGER.debug(
            "Integration %s does not provide trigger support, skipping",
            integration_domain,
        )
        return

    if automation.is_disabled_experimental_trigger(hass, integration_domain):
        _LOGGER.debug("Triggers for integration %s are disabled", integration_domain)
        return

    # We don't use gather here because gather adds additional overhead
    # when wrapping each coroutine in a task, and we expect our listeners
    # to call trigger.async_get_all_descriptions which will only yield
    # the first time it's called, after that it returns cached data.
    for listener in hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]:
        try:
            await listener(new_triggers)
        except Exception:
            # A failing listener is logged and must not prevent the remaining
            # listeners from being notified.
            _LOGGER.exception("Error while notifying trigger platform listener")


# Base trigger schema extended with the optional "options" and "target" keys;
# the options payload itself is validated by the concrete trigger class.
_TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
    {
        vol.Optional(CONF_OPTIONS): object,
        vol.Optional(CONF_TARGET): cv.TARGET_FIELDS,
    }
)


class Trigger(abc.ABC):
    """Trigger class.

    Abstract base: subclasses implement async_validate_config and
    async_attach_runner.
    """

    _hass: HomeAssistant

    @classmethod
    async def async_validate_complete_config(
        cls, hass: HomeAssistant, complete_config: ConfigType
    ) -> ConfigType:
        """Validate complete config.

        The complete config includes fields that are generic to all triggers,
        such as the alias or the ID.
        This method should be overridden by triggers that need to migrate
        from the old-style config.
        """
        complete_config = _TRIGGER_SCHEMA(complete_config)

        # Hand only the trigger specific keys to async_validate_config ...
        specific_config: ConfigType = {}
        for key in (CONF_OPTIONS, CONF_TARGET):
            if key in complete_config:
                specific_config[key] = complete_config.pop(key)
        specific_config = await cls.async_validate_config(hass, specific_config)

        # ... and merge the validated values back into the complete config.
        for key in (CONF_OPTIONS, CONF_TARGET):
            if key in specific_config:
                complete_config[key] = specific_config[key]

        return complete_config

    @classmethod
    @abc.abstractmethod
    async def async_validate_config(
        cls, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config.

        Receives only the "options" and "target" keys (when present) from
        async_validate_complete_config and returns the validated mapping.
        """

    def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
        """Initialize trigger."""
        # The base class only stores hass; `config` is consumed by subclasses.
        self._hass = hass

    async def async_attach_action(
        self,
        action: TriggerAction,
        action_payload_builder: TriggerActionPayloadBuilder,
    ) -> CALLBACK_TYPE:
        """Attach the trigger to an action.

        Wraps `action` in a runner that builds the payload with
        `action_payload_builder` and schedules the action as a task, then
        delegates to async_attach_runner.
        """

        @callback
        def run_action(
            extra_trigger_payload: dict[str, Any],
            description: str,
            context: Context | None = None,
        ) -> asyncio.Task[Any]:
            """Run action with trigger variables."""

            payload = action_payload_builder(extra_trigger_payload, description)
            return self._hass.async_create_task(action(payload, context))

        return await self.async_attach_runner(run_action)

    @abc.abstractmethod
    async def async_attach_runner(
        self, run_action: TriggerActionRunner
    ) -> CALLBACK_TYPE:
        """Attach the trigger to an action runner."""


ATTR_BEHAVIOR: Final = "behavior"
BEHAVIOR_FIRST: Final = "first"
BEHAVIOR_ALL: Final = "all"
BEHAVIOR_EACH: Final = "each"


def _backwards_compatible_behavior(value: Any) -> Any:
    """Convert legacy behavior values to new ones.

    "any" maps to BEHAVIOR_EACH and "last" to BEHAVIOR_ALL; every other value
    is returned unchanged for the following validator to accept or reject.
    """
    if value == "any":
        return BEHAVIOR_EACH
    if value == "last":
        return BEHAVIOR_ALL
    return value


# Schema for entity triggers that take a target and no options.
ENTITY_STATE_TRIGGER_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_TARGET): cv.TARGET_FIELDS,
        vol.Required(CONF_OPTIONS, default={}): {},
    }
)

# Same as above, plus the "behavior" option (defaults to BEHAVIOR_EACH) and
# an optional "for" duration.
ENTITY_STATE_TRIGGER_SCHEMA_WITH_BEHAVIOR = ENTITY_STATE_TRIGGER_SCHEMA.extend(
    {
        vol.Required(CONF_OPTIONS, default={}): {
            vol.Required(ATTR_BEHAVIOR, default=BEHAVIOR_EACH): vol.All(
                _backwards_compatible_behavior,
                vol.In([BEHAVIOR_FIRST, BEHAVIOR_ALL, BEHAVIOR_EACH]),
            ),
            vol.Optional(CONF_FOR): cv.positive_time_period,
        },
    }
)


class EntityTriggerBase(Trigger):
    """Trigger for entity state changes."""

    _domain_specs: Mapping[str, DomainSpec]
    # States filtered from the to_state pre-filter (and `_should_include`).
    _excluded_states: Final[frozenset[str]] = frozenset(
        {STATE_UNAVAILABLE, STATE_UNKNOWN}
    )
    # States filtered from the from_state pre-filter. Defaults to
    # `_excluded_states`. Subclasses can override to relax the origin
    # check.
    _excluded_from_states: ClassVar[frozenset[str]] = _excluded_states
    _schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA_WITH_BEHAVIOR
    # When True, indirect target expansion (via device/area/floor) skips
    # entities with an entity_category.
    _primary_entities_only: ClassVar[bool] = True

    @override
    @classmethod
    async def async_validate_config(
        cls, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config against the class-level `_schema`."""
        return cast(ConfigType, cls._schema(config))

    def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
        """Initialize the state trigger."""
        super().__init__(hass, config)
        if TYPE_CHECKING:
            assert config.target is not None
        self._options = config.options or {}
        # Optional "for" duration; None when the option is absent.
        self._duration: timedelta | None = self._options.get(CONF_FOR)
        self._target = config.target

    def entity_filter(self, entities: set[str]) -> set[str]:
        """Filter entities matching any of the domain specs."""
        return filter_by_domain_specs(self._hass, self._domain_specs, entities)

    def _get_tracked_value(self, state: State) -> Any:
        """Get the tracked value from a state based on the DomainSpec.

        Returns `state.state` when the spec has no `value_source`, otherwise
        the attribute named by `value_source` (None if it is missing).
        """
        domain_spec = self._domain_specs[state.domain]
        if domain_spec.value_source is None:
            return state.state
        return state.attributes.get(domain_spec.value_source)

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the transition should fire the trigger.

        Called only after `from_state.state` has been filtered against
        `_excluded_from_states` and `to_state.state` against
        `_excluded_states`, so subclasses don't need to repeat those checks.

        Default: any state change. Override to add semantics (specific
        from/to states, value changed across a threshold, etc.).
        """
        return from_state.state != to_state.state

    def is_valid_state(self, state: State) -> bool:
        """Check if the state is a target state for the trigger.

        Called only after `state.state` has been filtered against
        `_excluded_states`, so subclasses don't need to repeat that check.

        Default: any non-excluded state is a target. Override to restrict
        (specific to_states, value within a threshold, etc.).
        """
        return True

    def _should_include(self, state: State) -> bool:
        """Check if an entity should participate in all/count checks.

        The default implementation excludes only entities whose state.state
        is in `_excluded_states` (unavailable / unknown). Subclasses can
        override to also exclude entities that lack the optional capability
        the trigger relies on (e.g. a missing volume_level attribute).
        """
        return state.state not in self._excluded_states

    def count_matches(self, entity_ids: set[str]) -> tuple[int, int]:
        """Return (matches, included) for the entity set.

        `matches` is the number of entities that pass `_should_include` AND
        `is_valid_state`. `included` is the number that pass
        `_should_include` (i.e. are visible to the all/count check at all).
        Callers can use the pair to distinguish vacuous truth
        (`included == 0`) from a genuine all-match
        (`matches == included > 0`).
        """
        matches = 0
        included = 0
        for entity_id in entity_ids:
            state = self._hass.states.get(entity_id)
            # Entities without a state object are treated like excluded ones.
            if state is None or not self._should_include(state):
                continue
            included += 1
            if self.is_valid_state(state):
                matches += 1
        return matches, included

    @override
    async def async_attach_runner(
        self, run_action: TriggerActionRunner
    ) -> CALLBACK_TYPE:
        """Attach the trigger to an action runner.

        Subscribes to state changes of the targeted entities and calls
        `run_action` when a change passes the state, transition and behavior
        checks, immediately or after the configured `for` duration. Returns
        a callback that removes the subscription and any pending timers.
        """

        behavior: str = self._options.get(ATTR_BEHAVIOR, BEHAVIOR_EACH)
        # Pending `for` timers: keyed by entity_id for behavior "each",
        # otherwise by the behavior name (one shared timer).
        unsub_track_same: dict[str, Callable[[], None]] = {}

        @callback
        def state_change_listener(
            target_state_change_data: TargetStateChangedData,
        ) -> None:
            """Listen for state changes and call action."""
            event = target_state_change_data.state_change_event
            entity_id = event.data["entity_id"]
            from_state = event.data["old_state"]
            to_state = event.data["new_state"]

            def state_still_valid(
                _: str, from_state: State | None, to_state: State | None
            ) -> bool:
                """Check if the state is still valid during the duration wait.

                Called by async_track_same_state on each state change to
                determine whether to cancel the timer.
                For behavior each, checks the individual entity's state.
                For behavior first/all, checks the combined state.
                """
                if behavior == BEHAVIOR_ALL:
                    matches, included = self.count_matches(
                        target_state_change_data.targeted_entity_ids
                    )
                    # Require at least one included entity to avoid keeping
                    # the timer alive when every targeted entity has been
                    # filtered out since it started — a vacuous all-match
                    # (`included == 0`) would otherwise let the action fire
                    # after `for:` even though no entity still matches.
                    return included > 0 and matches == included
                if behavior == BEHAVIOR_FIRST:
                    matches, _included = self.count_matches(
                        target_state_change_data.targeted_entity_ids
                    )
                    return matches >= 1
                # Behavior each: check the individual entity's state
                if not to_state or to_state.state in self._excluded_states:
                    return False
                return self.is_valid_state(to_state)

            # Ignore events without both an old and a new state.
            if not from_state or not to_state:
                return

            # The trigger should never fire if the new state is excluded
            # or not a target state.
            if to_state.state in self._excluded_states or not self.is_valid_state(
                to_state
            ):
                return

            # The trigger should never fire if the origin state is excluded
            # or the transition is not valid.
            if (
                from_state.state in self._excluded_from_states
                or not self.is_valid_transition(from_state, to_state)
            ):
                return

            if behavior == BEHAVIOR_ALL:
                matches, included = self.count_matches(
                    target_state_change_data.targeted_entity_ids
                )
                if matches != included:
                    return
            elif behavior == BEHAVIOR_FIRST:
                # Note: It's enough to test for exactly 1 match here because if there
                # were previously 2 matches the transition would not be valid and we
                # would have returned already.
                matches, _ = self.count_matches(
                    target_state_change_data.targeted_entity_ids
                )
                if matches != 1:
                    return

            @callback
            def call_action() -> None:
                """Call action with right context."""
                # After a `for` delay, keep the original triggering event payload.
                # `async_track_same_state` only verifies the state remained valid
                # for the configured duration before firing the action.
                run_action(
                    {
                        ATTR_ENTITY_ID: entity_id,
                        "from_state": from_state,
                        "to_state": to_state,
                        "for": self._duration,
                    },
                    f"state of {entity_id}",
                    event.context,
                )

            if not self._duration:
                # Call action immediately if duration is not specified or 0
                call_action()
                return

            subscription_key = entity_id if behavior == BEHAVIOR_EACH else behavior
            # Restart the timer: cancel a pending one for the same key first.
            if subscription_key in unsub_track_same:
                unsub_track_same.pop(subscription_key)()
            unsub_track_same[subscription_key] = async_track_same_state(
                self._hass,
                self._duration,
                call_action,
                state_still_valid,
                entity_ids=(
                    entity_id
                    if behavior == BEHAVIOR_EACH
                    else target_state_change_data.targeted_entity_ids
                ),
            )

        unsub = async_track_target_selector_state_change_event(
            self._hass,
            self._target,
            state_change_listener,
            self.entity_filter,
            primary_entities_only=self._primary_entities_only,
        )

        @callback
        def async_remove() -> None:
            """Remove state listeners async."""
            unsub()
            # NOTE(review): the loop variable reuses this function's own name;
            # it is a local here so it works, but a distinct name would read
            # more clearly.
            for async_remove in unsub_track_same.values():
                async_remove()
            unsub_track_same.clear()

        return async_remove


class EntityTargetStateTriggerBase(EntityTriggerBase):
    """Trigger for entity state changes to a specific state.

    Uses _get_tracked_value to extract the value, so it works for both
    state-based and attribute-based triggers depending on the DomainSpec.
    """

    _to_states: set[str]

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check the value changed and the origin was not already a target state."""
        from_value = self._get_tracked_value(from_state)
        return (
            from_value != self._get_tracked_value(to_state)
            and from_value not in self._to_states
        )

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state matches the expected state."""
        return self._get_tracked_value(state) in self._to_states


class EntityTransitionTriggerBase(EntityTriggerBase):
    """Trigger for entity state changes between specific states."""

    _from_states: set[str | bool]
    _to_states: set[str | bool]

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state matches the expected ones."""
        from_value = self._get_tracked_value(from_state)
        return (
            from_value != self._get_tracked_value(to_state)
            and from_value in self._from_states
        )

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state matches the expected states."""
        return self._get_tracked_value(state) in self._to_states


class EntityOriginStateTriggerBase(EntityTriggerBase):
    """Trigger for entity state changes from a specific state."""

    _from_state: str

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if origin state matches expected and that the state changed."""
        return bool(
            self._get_tracked_value(from_state) == self._from_state
            and self._get_tracked_value(to_state) != self._from_state
        )

    def is_valid_state(self, state: State) -> bool:
        """Check that the new state is different from the origin state."""
        return bool(self._get_tracked_value(state) != self._from_state)


class StatelessEntityTriggerBase(EntityTriggerBase):
    """Trigger for entities that don't carry meaningful state.

    Used for stateless entities (buttons, scenes, doorbells, events) whose
    `state.state` is just a timestamp of the last activation.
    `STATE_UNKNOWN` is a legitimate prior state — the first activation after
    startup must still fire the trigger.
    """

    # No behavior / for options: the plain target-only schema is used.
    _schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA
    _excluded_from_states: ClassVar[frozenset[str]] = frozenset({STATE_UNAVAILABLE})


# Schema for "changed" numerical triggers: target plus a required
# "threshold" option validated by the numeric threshold selector.
NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA = ENTITY_STATE_TRIGGER_SCHEMA.extend(
    {
        vol.Required(CONF_OPTIONS, default={}): vol.All(
            {
                vol.Required("threshold"): NumericThresholdSelector(
                    NumericThresholdSelectorConfig(mode=NumericThresholdMode.CHANGED)
                )
            },
        )
    }
)


class EntityNumericalStateTriggerBase(EntityTriggerBase):
    """Base class for numerical state and state attribute triggers."""

    # Unit the tracked entity state must report; UNDEFINED disables the check
    # (see _is_valid_unit).
    _valid_unit: str | None | UndefinedType = UNDEFINED
    _threshold_type: NumericThresholdType

    def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
        """Initialize the state trigger."""
        super().__init__(hass, config)
        threshold_options: dict[str, Any] = self._options["threshold"]
        # "value" is read for ABOVE/BELOW, "value_min"/"value_max" for
        # BETWEEN/OUTSIDE (see is_valid_state).
        self.threshold = ThresholdConfig.from_config(threshold_options.get("value"))
        self.lower_threshold = ThresholdConfig.from_config(
            threshold_options.get("value_min")
        )
        self.upper_threshold = ThresholdConfig.from_config(
            threshold_options.get("value_max")
        )
        self._threshold_type = threshold_options["type"]

    def _is_valid_unit(self, unit: str | None) -> bool:
        """Check if the given unit is valid for this trigger."""
        if isinstance(self._valid_unit, UndefinedType):
            return True
        return unit == self._valid_unit

    def _get_threshold_value(self, threshold: ThresholdConfig | None) -> float | None:
        """Get threshold value from float or entity state.

        Returns None when no threshold is configured, the referenced entity
        does not exist, its unit is not valid, or its state is not a number.
        """
        if threshold is None:
            return None
        if threshold.numerical:
            return threshold.number
        if not (state := self._hass.states.get(threshold.entity)):  # type: ignore[arg-type]
            # Entity not found
            return None
        if not self._is_valid_unit(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)):
            # Entity unit does not match the expected unit
            return None
        try:
            return float(state.state)
        # Unparenthesised multi-exception form (PEP 758, Python 3.14+):
        # catches both TypeError and ValueError.
        except TypeError, ValueError:
            # Entity state is not a valid number
            return None

    def _get_tracked_value(self, state: State) -> float | None:
        """Get the tracked numerical value from a state.

        The unit check is only applied when the value comes from
        `state.state`; attribute values are converted without it.
        """
        domain_spec = self._domain_specs[state.domain]
        raw_value: Any
        if domain_spec.value_source is None:
            if not self._is_valid_unit(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)):
                return None
            raw_value = state.state
        else:
            raw_value = state.attributes.get(domain_spec.value_source)

        try:
            return float(raw_value)
        except TypeError, ValueError:
            # Entity state is not a valid number
            return None

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state or state attribute matches the expected one."""
        # Handle missing or None value case first to avoid expensive exceptions
        if (current_value := self._get_tracked_value(state)) is None:
            return False

        if self._threshold_type == NumericThresholdType.ANY:
            # If the threshold type is "any" we always trigger on valid state
            # changes
            return True

        if self._threshold_type == NumericThresholdType.ABOVE:
            if (limit := self._get_threshold_value(self.threshold)) is None:
                # Entity not found or invalid number, don't trigger
                return False
            return current_value > limit

        if self._threshold_type == NumericThresholdType.BELOW:
            if (limit := self._get_threshold_value(self.threshold)) is None:
                # Entity not found or invalid number, don't trigger
                return False
            return current_value < limit

        # Mode is BETWEEN or OUTSIDE
        lower_limit = self._get_threshold_value(self.lower_threshold)
        upper_limit = self._get_threshold_value(self.upper_threshold)
        if lower_limit is None or upper_limit is None:
            # Entity not found or invalid number, don't trigger
            return False
        # Both bounds are inclusive for BETWEEN (and so exclusive for OUTSIDE).
        between = lower_limit <= current_value <= upper_limit
        if self._threshold_type == NumericThresholdType.BETWEEN:
            return between
        return not between


class EntityNumericalStateTriggerWithUnitBase(EntityNumericalStateTriggerBase):
    """Base class for numerical triggers that convert values to a base unit.

    Thresholds and tracked values are converted to `_base_unit` with
    `_unit_converter` before they are compared.
    """

    _base_unit: str | None  # Base unit for the tracked value
    _unit_converter: type[BaseUnitConverter]

    def _get_entity_unit(self, state: State) -> str | None:
        """Get the unit of an entity from its state."""
        return state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)

    def _get_threshold_value(self, threshold: ThresholdConfig | None) -> float | None:
        """Get threshold value from float or entity state.

        The value is converted to `_base_unit`. For an entity threshold,
        None is returned when the entity is missing, its state is not a
        number, or the unit conversion fails.
        """
        if threshold is None:
            return None
        if threshold.numerical:
            # NOTE(review): unlike the entity branch below, a conversion
            # error is not caught here; presumably the schema guarantees a
            # valid unit for numerical thresholds — confirm.
            return self._unit_converter.convert(
                threshold.number,  # type: ignore[arg-type]
                threshold.unit,  # type: ignore[arg-type]
                self._base_unit,
            )
        if not (state := self._hass.states.get(threshold.entity)):  # type: ignore[arg-type]
            # Entity not found
            return None
        try:
            value = float(state.state)
        except TypeError, ValueError:
            # Entity state is not a valid number
            return None

        try:
            return self._unit_converter.convert(
                value, state.attributes.get(ATTR_UNIT_OF_MEASUREMENT), self._base_unit
            )
        except HomeAssistantError:
            # Unit conversion failed (i.e. incompatible units), treat as invalid number
            return None

    def _get_tracked_value(self, state: State) -> float | None:
        """Get the tracked numerical value from a state, in `_base_unit`."""
        domain_spec = self._domain_specs[state.domain]
        raw_value: Any
        if domain_spec.value_source is None:
            raw_value = state.state
        else:
            raw_value = state.attributes.get(domain_spec.value_source)

        try:
            value = float(raw_value)
        except TypeError, ValueError:
            # Entity state is not a valid number
            return None

        try:
            return self._unit_converter.convert(
                value, self._get_entity_unit(state), self._base_unit
            )
        except HomeAssistantError:
            # Unit conversion failed (i.e. incompatible units), treat as invalid number
            return None


class EntityNumericalStateChangedTriggerBase(EntityNumericalStateTriggerBase):
    """Trigger for numerical state and state attribute changes."""

    _schema = NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the tracked numeric value has changed."""
        return self._get_tracked_value(from_state) != self._get_tracked_value(to_state)


def make_numerical_state_changed_with_unit_schema(
    unit_converter: type[BaseUnitConverter],
) -> vol.Schema:
    """Factory for numerical state trigger schema with unit option.

    The threshold selector accepts the units listed in
    `unit_converter.VALID_UNITS`.
    """
    return ENTITY_STATE_TRIGGER_SCHEMA.extend(
        {
            vol.Required(CONF_OPTIONS, default={}): vol.All(
                {
                    vol.Required("threshold"): NumericThresholdSelector(
                        NumericThresholdSelectorConfig(
                            mode=NumericThresholdMode.CHANGED,
                            unit_of_measurement=list(unit_converter.VALID_UNITS),
                        )
                    )
                },
            )
        }
    )


class EntityNumericalStateChangedTriggerWithUnitBase(
    EntityNumericalStateChangedTriggerBase,
    EntityNumericalStateTriggerWithUnitBase,
):
    """Trigger for numerical state and state attribute changes, with units."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create a schema."""
        super().__init_subclass__(**kwargs)
        # The schema depends on the subclass's unit converter, so it is built
        # once per subclass at class creation time.
        cls._schema = make_numerical_state_changed_with_unit_schema(cls._unit_converter)


# Schema for "crossed threshold" numerical triggers: extends the
# behavior-aware schema with a required "threshold" option.
NUMERICAL_ATTRIBUTE_CROSSED_THRESHOLD_SCHEMA = (
    ENTITY_STATE_TRIGGER_SCHEMA_WITH_BEHAVIOR.extend(
        {
            vol.Required(CONF_OPTIONS): {
                vol.Required("threshold"): NumericThresholdSelector(
                    NumericThresholdSelectorConfig(mode=NumericThresholdMode.CROSSED)
                ),
            },
        }
    )
)


class EntityNumericalStateCrossedThresholdTriggerBase(EntityNumericalStateTriggerBase):
    """Trigger for numerical state and state attribute changes.

    This trigger only fires when the observed attribute changes from not within to within
    the defined threshold.
    """

    _schema = NUMERICAL_ATTRIBUTE_CROSSED_THRESHOLD_SCHEMA

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check that the tracked value crossed into the threshold range."""
        # The new state is checked separately with is_valid_state by the
        # listener, so only the origin has to be outside the threshold here.
        return not self.is_valid_state(from_state)


def _make_numerical_state_crossed_threshold_with_unit_schema(
    unit_converter: type[BaseUnitConverter],
) -> vol.Schema:
    """Create the crossed-threshold trigger schema with a unit option.

    Extends the behavior-aware entity schema with a required "threshold"
    option in CROSSED mode; the selector accepts the units listed in
    `unit_converter.VALID_UNITS`.
    """
    return ENTITY_STATE_TRIGGER_SCHEMA_WITH_BEHAVIOR.extend(
        {
            vol.Required(CONF_OPTIONS, default={}): {
                vol.Required("threshold"): NumericThresholdSelector(
                    NumericThresholdSelectorConfig(
                        mode=NumericThresholdMode.CROSSED,
                        unit_of_measurement=list(unit_converter.VALID_UNITS),
                    )
                ),
            },
        }
    )


class EntityNumericalStateCrossedThresholdTriggerWithUnitBase(
    EntityNumericalStateCrossedThresholdTriggerBase,
    EntityNumericalStateTriggerWithUnitBase,
):
    """Trigger for numerical values crossing a threshold, with units."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create a schema."""
        super().__init_subclass__(**kwargs)
        # Built per subclass because it depends on the subclass's converter.
        cls._schema = _make_numerical_state_crossed_threshold_with_unit_schema(
            cls._unit_converter
        )


def _normalize_domain_specs(
    domain_specs: Mapping[str, DomainSpec] | str,
) -> Mapping[str, DomainSpec]:
    """Normalize domain_specs argument to a Mapping.

    A plain domain name becomes a single-entry mapping with a default
    DomainSpec; a mapping is returned as is.
    """
    if isinstance(domain_specs, str):
        return {domain_specs: DomainSpec()}
    return domain_specs


def make_entity_target_state_trigger(
    domain_specs: Mapping[str, DomainSpec] | str,
    to_states: str | set[str],
    *,
    primary_entities_only: bool = True,
) -> type[EntityTargetStateTriggerBase]:
    """Create a trigger for entity state changes to specific state(s).

    domain_specs can be a string (domain name) for simple state-based triggers,
    or a Mapping[str, DomainSpec] for attribute-based or multi-domain triggers.
    """
    specs = _normalize_domain_specs(domain_specs)

    # Accept a single state as shorthand for a one-element set.
    if isinstance(to_states, str):
        to_states_set = {to_states}
    else:
        to_states_set = to_states

    class CustomTrigger(EntityTargetStateTriggerBase):
        """Trigger for entity state changes."""

        _domain_specs = specs
        _to_states = to_states_set
        _primary_entities_only = primary_entities_only

    return CustomTrigger


def make_entity_transition_trigger(
    domain_specs: Mapping[str, DomainSpec] | str,
    *,
    from_states: set[str | bool],
    to_states: set[str | bool],
) -> type[EntityTransitionTriggerBase]:
    """Create a trigger for entity state changes between specific states.

    domain_specs can be a string (domain name) for simple state-based triggers,
    or a Mapping[str, DomainSpec] for attribute-based or multi-domain triggers.
    """
    specs = _normalize_domain_specs(domain_specs)

    class CustomTrigger(EntityTransitionTriggerBase):
        """Trigger for conditional entity state changes."""

        _domain_specs = specs
        _from_states = from_states
        _to_states = to_states

    return CustomTrigger


def make_entity_origin_state_trigger(
    domain_specs: Mapping[str, DomainSpec] | str,
    *,
    from_state: str,
) -> type[EntityOriginStateTriggerBase]:
    """Create a trigger for entity state changes from a specific state.

    domain_specs can be a string (domain name) for simple state-based triggers,
    or a Mapping[str, DomainSpec] for attribute-based or multi-domain triggers.
    """
    specs = _normalize_domain_specs(domain_specs)

    class CustomTrigger(EntityOriginStateTriggerBase):
        """Trigger for entity "from state" changes."""

        _domain_specs = specs
        _from_state = from_state

    return CustomTrigger


def make_entity_numerical_state_changed_trigger(
    domain_specs: Mapping[str, DomainSpec],
    valid_unit: str | None | UndefinedType = UNDEFINED,
    *,
    primary_entities_only: bool = True,
) -> type[EntityNumericalStateChangedTriggerBase]:
    """Create a trigger for numerical state value change."""

    class CustomTrigger(EntityNumericalStateChangedTriggerBase):
        """Trigger for numerical state value changes."""

        _domain_specs = domain_specs
        _valid_unit = valid_unit
        _primary_entities_only = primary_entities_only

    return CustomTrigger


def make_entity_numerical_state_crossed_threshold_trigger(
    domain_specs: Mapping[str, DomainSpec],
    valid_unit: str | None | UndefinedType = UNDEFINED,
    *,
    primary_entities_only: bool = True,
) -> type[EntityNumericalStateCrossedThresholdTriggerBase]:
    """Create a trigger for numerical state value crossing a threshold."""

    class CustomTrigger(EntityNumericalStateCrossedThresholdTriggerBase):
        """Trigger for numerical state value crossing a threshold."""

        _domain_specs = domain_specs
        _valid_unit = valid_unit
        _primary_entities_only = primary_entities_only

    return CustomTrigger


def make_entity_numerical_state_changed_with_unit_trigger(
    domain_specs: Mapping[str, DomainSpec],
    base_unit: str,
    unit_converter: type[BaseUnitConverter],
) -> type[EntityNumericalStateChangedTriggerWithUnitBase]:
    """Create a trigger for numerical state value change."""

    class CustomTrigger(EntityNumericalStateChangedTriggerWithUnitBase):
        """Trigger for numerical state value changes."""

        _domain_specs = domain_specs
        _base_unit = base_unit
        _unit_converter = unit_converter

    return CustomTrigger


def make_entity_numerical_state_crossed_threshold_with_unit_trigger(
    domain_specs: Mapping[str, DomainSpec],
    base_unit: str,
    unit_converter: type[BaseUnitConverter],
) -> type[EntityNumericalStateCrossedThresholdTriggerWithUnitBase]:
    """Create a trigger for numerical state value crossing a threshold."""

    class CustomTrigger(EntityNumericalStateCrossedThresholdTriggerWithUnitBase):
        """Trigger for numerical state value crossing a threshold."""

        _domain_specs = domain_specs
        _base_unit = base_unit
        _unit_converter = unit_converter

    return CustomTrigger


class TriggerProtocol(Protocol):
    """Define the format of trigger modules.

    New implementations should only implement async_get_triggers.
    """

    async def async_get_triggers(self, hass: HomeAssistant) -> dict[str, type[Trigger]]:
        """Return the triggers provided by this integration."""

    TRIGGER_SCHEMA: vol.Schema

    async def async_validate_trigger_config(
        self, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config."""

    async def async_attach_trigger(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        action: TriggerActionType,
        trigger_info: TriggerInfo,
    ) -> CALLBACK_TYPE:
        """Attach a trigger."""


@dataclass(slots=True, frozen=True)
class TriggerConfig:
    """Trigger config."""

    key: str  # The key used to identify the trigger, e.g. "zwave.event"
    target: dict[str, Any] | None = None
    options: dict[str, Any] | None = None


class TriggerActionRunner(Protocol):
    """Protocol type for the trigger action runner helper callback."""

    @callback
    def __call__(
        self,
        extra_trigger_payload: dict[str, Any],
        description: str,
        context: Context | None = None,
    ) -> asyncio.Task[Any]:
        """Define trigger action runner type.

        Returns:
            A Task that allows awaiting for the action to finish.

        """


class TriggerActionPayloadBuilder(Protocol):
    """Protocol type for the trigger action payload builder."""

    def __call__(
        self, extra_trigger_payload: dict[str, Any], description: str
    ) -> dict[str, Any]:
        """Define trigger action payload builder type."""


class TriggerAction(Protocol):
    """Protocol type for trigger action callback."""

    async def __call__(
        self, run_variables: dict[str, Any], context: Context | None = None
    ) -> Any:
        """Define action callback type."""


class TriggerActionType(Protocol):
    """Protocol type for trigger action callback.

    Contrary to TriggerAction, this type supports both sync and async callables.
    """

    def __call__(
        self,
        run_variables: dict[str, Any],
        context: Context | None = None,
    ) -> Coroutine[Any, Any, Any] | Any:
        """Define action callback type."""


class TriggerData(TypedDict):
    """Trigger data."""

    id: str
    idx: str
    alias: str | None


class TriggerInfo(TypedDict):
    """Information about trigger."""

    domain: str
    name: str
    home_assistant_start: bool
    variables: TemplateVarsType
    trigger_data: TriggerData


@dataclass(slots=True)
class PluggableActionsEntry:
    """Holder to keep track of all plugs and actions for a given trigger."""

    plugs: set[PluggableAction] = field(default_factory=set)
    # Keyed by the removal callback returned from
    # PluggableAction.async_attach_trigger; the value is (job, variables).
    actions: dict[
        object,
        tuple[
            HassJob[[dict[str, Any], Context | None], Coroutine[Any, Any, None] | Any],
            dict[str, Any],
        ],
    ] = field(default_factory=dict)


class PluggableAction:
    """A pluggable action handler."""

    _entry: PluggableActionsEntry | None = None

    def __init__(self, update: CALLBACK_TYPE | None = None) -> None:
        """Initialize a pluggable action.

        :param update: callback triggered whenever triggers are attached or removed.
        """
        self._update = update

    def __bool__(self) -> bool:
        """Return if we have something attached."""
        return bool(self._entry and self._entry.actions)

    @callback
    def async_run_update(self) -> None:
        """Run update function if one exists."""
        if self._update:
            self._update()

    @staticmethod
    @callback
    def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
        """Return the pluggable actions registry."""
        # NOTE(review): truthiness check, so an existing but empty registry is
        # replaced by a fresh one as well.
        if data := hass.data.get(DATA_PLUGGABLE_ACTIONS):
            return data
        data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(PluggableActionsEntry)
        return data

    @staticmethod
    @callback
    def async_attach_trigger(
        hass: HomeAssistant,
        trigger: dict[str, str],
        action: TriggerActionType,
        variables: dict[str, Any],
    ) -> CALLBACK_TYPE:
        """Attach an action to a trigger entry.

        Existing or future plugs registered will be attached.
        """
        reg = PluggableAction.async_get_registry(hass)
        # Order-independent, hashable key built from the trigger mapping.
        key = tuple(sorted(trigger.items()))
        entry = reg[key]

        def _update() -> None:
            for plug in entry.plugs:
                plug.async_run_update()

        @callback
        def _remove() -> None:
            """Remove this action attachment, and disconnect all plugs."""
            del entry.actions[_remove]
            _update()
            if not entry.actions and not entry.plugs:
                del reg[key]

        job = HassJob(action, f"trigger {trigger} {variables}")
        # The removal callback doubles as the unique key of this attachment.
        entry.actions[_remove] = (job, variables)
        _update()

        return _remove

    @callback
    def async_register(
        self, hass: HomeAssistant, trigger: dict[str, str]
    ) -> CALLBACK_TYPE:
        """Register plug in the global plugs dictionary."""

        reg = PluggableAction.async_get_registry(hass)
        key = tuple(sorted(trigger.items()))
        self._entry = reg[key]
        self._entry.plugs.add(self)

        @callback
        def _remove() -> None:
            """Remove plug from registration.

            Clean up entry if there are no actions or plugs registered.
""" assert self._entry self._entry.plugs.remove(self) if not self._entry.actions and not self._entry.plugs: del reg[key] self._entry = None return _remove async def async_run( self, hass: HomeAssistant, context: Context | None = None ) -> None: """Run all actions.""" assert self._entry for job, variables in self._entry.actions.values(): task = hass.async_run_hass_job(job, variables, context) if task: await task async def _async_get_trigger_platform( hass: HomeAssistant, trigger_key: str ) -> tuple[str, TriggerProtocol]: from homeassistant.components import automation # noqa: PLC0415 platform_and_sub_type = trigger_key.split(".") platform = platform_and_sub_type[0] # Only apply aliases for old-style triggers (no sub_type). # New-style triggers (e.g. "event.received") use the integration domain directly. if len(platform_and_sub_type) == 1: platform = _PLATFORM_ALIASES.get(platform, platform) if automation.is_disabled_experimental_trigger(hass, platform): raise vol.Invalid( f"Trigger '{trigger_key}' requires the experimental 'New triggers and " "conditions' feature to be enabled in Home Assistant Labs settings " f"(feature flag: '{automation.NEW_TRIGGERS_CONDITIONS_FEATURE_FLAG}')" ) try: integration = await async_get_integration(hass, platform) except IntegrationNotFound: raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None try: platform_module = await integration.async_get_platform("trigger") except ImportError: raise vol.Invalid( f"Integration '{platform}' does not provide trigger support" ) from None # Ensure triggers are registered so descriptions can be loaded await _register_trigger_platform(hass, platform, platform_module) return platform, platform_module async def async_validate_trigger_config( hass: HomeAssistant, trigger_config: list[ConfigType] ) -> list[ConfigType]: """Validate triggers.""" config = [] for conf in trigger_config: trigger_key: str = conf[CONF_PLATFORM] platform_domain, platform = await _async_get_trigger_platform(hass, 
trigger_key) if hasattr(platform, "async_get_triggers"): trigger_descriptors = await platform.async_get_triggers(hass) relative_trigger_key = get_relative_description_key( platform_domain, trigger_key ) if not (trigger := trigger_descriptors.get(relative_trigger_key)): raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") conf = await trigger.async_validate_complete_config(hass, conf) elif hasattr(platform, "async_validate_trigger_config"): conf = move_options_fields_to_top_level(conf, cv.TRIGGER_BASE_SCHEMA) conf = await platform.async_validate_trigger_config(hass, conf) else: conf = move_options_fields_to_top_level(conf, cv.TRIGGER_BASE_SCHEMA) conf = platform.TRIGGER_SCHEMA(conf) config.append(conf) return config def _trigger_action_wrapper( hass: HomeAssistant, action: Callable, conf: ConfigType ) -> Callable: """Wrap trigger action with extra vars if configured. If action is a coroutine function, a coroutine function will be returned. If action is a callback, a callback will be returned. 
""" if CONF_VARIABLES not in conf: return action # Check for partials to properly determine if coroutine function check_func = action while isinstance(check_func, functools.partial): check_func = check_func.func wrapper_func: Callable[..., Any] | Callable[..., Coroutine[Any, Any, Any]] if inspect.iscoroutinefunction(check_func): async_action = cast(Callable[..., Coroutine[Any, Any, Any]], action) @functools.wraps(async_action) async def async_with_vars( run_variables: dict[str, Any], context: Context | None = None ) -> Any: """Wrap action with extra vars.""" trigger_variables = conf[CONF_VARIABLES] run_variables.update(trigger_variables.async_render(hass, run_variables)) return await action(run_variables, context) wrapper_func = async_with_vars else: @functools.wraps(action) def with_vars( run_variables: dict[str, Any], context: Context | None = None ) -> Any: """Wrap action with extra vars.""" trigger_variables = conf[CONF_VARIABLES] run_variables.update(trigger_variables.async_render(hass, run_variables)) return action(run_variables, context) if is_callback(check_func): with_vars = callback(with_vars) wrapper_func = with_vars return wrapper_func async def _async_attach_trigger_cls( hass: HomeAssistant, trigger_cls: type[Trigger], trigger_key: str, conf: ConfigType, action: Callable, trigger_info: TriggerInfo, ) -> CALLBACK_TYPE: """Initialize a new Trigger class and attach it.""" def action_payload_builder( extra_trigger_payload: dict[str, Any], description: str ) -> dict[str, Any]: """Build action variables.""" payload = { "trigger": { **trigger_info["trigger_data"], CONF_PLATFORM: trigger_key, "description": description, **extra_trigger_payload, } } if CONF_VARIABLES in conf: trigger_variables = conf[CONF_VARIABLES] payload.update(trigger_variables.async_render(hass, payload)) return payload # Wrap sync action so that it is always async. 
# This simplifies the Trigger action runner interface by # always returning a coroutine, removing the need for # integrations to check for the return type when awaiting # the action. match get_hassjob_callable_job_type(action): case HassJobType.Executor: original_action = action async def wrapped_executor_action( run_variables: dict[str, Any], context: Context | None = None ) -> Any: """Wrap sync action to be called in executor.""" return await hass.async_add_executor_job( original_action, run_variables, context ) action = wrapped_executor_action case HassJobType.Callback: original_action = action async def wrapped_callback_action( run_variables: dict[str, Any], context: Context | None = None ) -> Any: """Wrap callback action to be awaitable.""" return original_action(run_variables, context) action = wrapped_callback_action trigger = trigger_cls( hass, TriggerConfig( key=trigger_key, target=conf.get(CONF_TARGET), options=conf.get(CONF_OPTIONS), ), ) return await trigger.async_attach_action(action, action_payload_builder) async def async_initialize_triggers( hass: HomeAssistant, trigger_config: list[ConfigType], action: Callable, domain: str, name: str, log_cb: Callable, home_assistant_start: bool = False, variables: TemplateVarsType = None, ) -> CALLBACK_TYPE | None: """Initialize triggers.""" triggers: list[asyncio.Task[CALLBACK_TYPE]] = [] for idx, conf in enumerate(trigger_config): # Skip triggers that are not enabled if CONF_ENABLED in conf: enabled = conf[CONF_ENABLED] if isinstance(enabled, Template): try: enabled = enabled.async_render(variables, limited=True) except TemplateError as err: log_cb(logging.ERROR, f"Error rendering enabled template: {err}") continue if not enabled: continue trigger_key: str = conf[CONF_PLATFORM] platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key) trigger_id = conf.get(CONF_ID, f"{idx}") trigger_idx = f"{idx}" trigger_alias = conf.get(CONF_ALIAS) trigger_data = TriggerData(id=trigger_id, 
idx=trigger_idx, alias=trigger_alias) info = TriggerInfo( domain=domain, name=name, home_assistant_start=home_assistant_start, variables=variables, trigger_data=trigger_data, ) if hasattr(platform, "async_get_triggers"): trigger_descriptors = await platform.async_get_triggers(hass) relative_trigger_key = get_relative_description_key( platform_domain, trigger_key ) trigger_cls = trigger_descriptors[relative_trigger_key] coro = _async_attach_trigger_cls( hass, trigger_cls, trigger_key, conf, action, info ) else: action_wrapper = _trigger_action_wrapper(hass, action, conf) coro = platform.async_attach_trigger(hass, conf, action_wrapper, info) triggers.append(create_eager_task(coro)) attach_results = await asyncio.gather(*triggers, return_exceptions=True) removes: list[Callable[[], None]] = [] for result in attach_results: if isinstance(result, HomeAssistantError): log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for") elif isinstance(result, Exception): log_cb(logging.ERROR, "Error setting up trigger", exc_info=result) elif isinstance(result, BaseException): raise result from None elif result is None: log_cb( # type: ignore[unreachable] logging.ERROR, "Unknown error while setting up trigger (empty result)" ) else: removes.append(result) if not removes: return None log_cb(logging.INFO, "Initialized trigger") @callback def remove_triggers() -> None: """Remove triggers.""" for remove in removes: remove() return remove_triggers def _load_triggers_file(integration: Integration) -> dict[str, Any]: """Load triggers file for an integration.""" try: return cast( dict[str, Any], _TRIGGERS_DESCRIPTION_SCHEMA( load_yaml_dict(str(integration.file_path / "triggers.yaml")) ), ) except FileNotFoundError: _LOGGER.warning( "Unable to find triggers.yaml for the %s integration", integration.domain ) return {} except (HomeAssistantError, vol.Invalid) as ex: _LOGGER.warning( "Unable to parse triggers.yaml for the %s integration: %s", integration.domain, ex, ) return {} 
def _load_triggers_files( integrations: Iterable[Integration], ) -> dict[str, dict[str, Any]]: """Load trigger files for multiple integrations.""" return { integration.domain: { get_absolute_description_key(integration.domain, key): value for key, value in _load_triggers_file(integration).items() } for integration in integrations } async def async_get_all_descriptions( hass: HomeAssistant, ) -> dict[str, dict[str, Any] | None]: """Return descriptions (i.e. user documentation) for all triggers.""" from homeassistant.components import automation # noqa: PLC0415 descriptions_cache = hass.data[TRIGGER_DESCRIPTION_CACHE] triggers = hass.data[TRIGGERS] # See if there are new triggers not seen before. # Any trigger that we saw before already has an entry in description_cache. all_triggers = set(triggers) previous_all_triggers = set(descriptions_cache) # If the triggers are the same, we can return the cache # mypy complains: Invalid index type "HassKey[set[str]]" for "HassDict" if previous_all_triggers | hass.data[TRIGGER_DISABLED_TRIGGERS] == all_triggers: # type: ignore[index] return descriptions_cache # Files we loaded for missing descriptions new_triggers_descriptions: dict[str, dict[str, Any]] = {} # We try to avoid making a copy in the event the cache is good, # but now we must make a copy in case new triggers get added # while we are loading the missing ones so we do not # add the new ones to the cache without their descriptions triggers = triggers.copy() if missing_triggers := all_triggers.difference(descriptions_cache): domains_with_missing_triggers = { triggers[missing_trigger] for missing_trigger in missing_triggers } ints_or_excs = await async_get_integrations(hass, domains_with_missing_triggers) integrations: list[Integration] = [] for domain, int_or_exc in ints_or_excs.items(): if type(int_or_exc) is Integration and int_or_exc.has_triggers: integrations.append(int_or_exc) continue if TYPE_CHECKING: assert isinstance(int_or_exc, Exception) _LOGGER.debug( 
"Failed to load triggers.yaml for integration: %s", domain, exc_info=int_or_exc, ) if integrations: new_triggers_descriptions = await hass.async_add_executor_job( _load_triggers_files, integrations ) # Make a copy of the old cache and add missing descriptions to it new_descriptions_cache = descriptions_cache.copy() for missing_trigger in missing_triggers: domain = triggers[missing_trigger] if automation.is_disabled_experimental_trigger(hass, domain): hass.data[TRIGGER_DISABLED_TRIGGERS].add(missing_trigger) continue if ( yaml_description := new_triggers_descriptions.get(domain, {}).get( missing_trigger ) ) is None: _LOGGER.debug( "No trigger descriptions found for trigger %s, skipping", missing_trigger, ) new_descriptions_cache[missing_trigger] = None continue description = {"fields": yaml_description.get("fields", {})} if (target := yaml_description.get("target")) is not None: description["target"] = target new_descriptions_cache[missing_trigger] = description hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache return new_descriptions_cache @callback def async_extract_devices(trigger_conf: dict) -> list[str]: """Extract devices from a trigger config.""" if trigger_conf[CONF_PLATFORM] == "device": return [trigger_conf[CONF_DEVICE_ID]] if ( trigger_conf[CONF_PLATFORM] == "event" and CONF_EVENT_DATA in trigger_conf and CONF_DEVICE_ID in trigger_conf[CONF_EVENT_DATA] and isinstance(trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID], str) ): return [trigger_conf[CONF_EVENT_DATA][CONF_DEVICE_ID]] if trigger_conf[CONF_PLATFORM] == "tag" and CONF_DEVICE_ID in trigger_conf: return trigger_conf[CONF_DEVICE_ID] # type: ignore[no-any-return] if target_devices := async_extract_targets(trigger_conf, CONF_DEVICE_ID): return target_devices return [] @callback def async_extract_entities(trigger_conf: dict) -> list[str]: """Extract entities from a trigger config.""" if trigger_conf[CONF_PLATFORM] in ("state", "numeric_state"): return trigger_conf[CONF_ENTITY_ID] # type: 
ignore[no-any-return] if trigger_conf[CONF_PLATFORM] == "calendar": return [trigger_conf[CONF_OPTIONS][CONF_ENTITY_ID]] if trigger_conf[CONF_PLATFORM] == "zone": options = trigger_conf[CONF_OPTIONS] return [*options[CONF_ENTITY_ID], options[CONF_ZONE]] if trigger_conf[CONF_PLATFORM] in ("zone.entered", "zone.left"): return [ *async_extract_targets(trigger_conf, CONF_ENTITY_ID), trigger_conf[CONF_OPTIONS][CONF_ZONE], ] if trigger_conf[CONF_PLATFORM] == "geo_location": return [trigger_conf[CONF_ZONE]] if trigger_conf[CONF_PLATFORM] == "sun": return ["sun.sun"] if ( trigger_conf[CONF_PLATFORM] == "event" and CONF_EVENT_DATA in trigger_conf and CONF_ENTITY_ID in trigger_conf[CONF_EVENT_DATA] and isinstance(trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID], str) and valid_entity_id(trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID]) ): return [trigger_conf[CONF_EVENT_DATA][CONF_ENTITY_ID]] if target_entities := async_extract_targets(trigger_conf, CONF_ENTITY_ID): return target_entities return [] @callback def async_extract_targets( config: dict, target: Literal["entity_id", "device_id", "area_id", "floor_id", "label_id"], ) -> list[str]: """Extract targets from a target config.""" if not (target_conf := config.get(CONF_TARGET)): return [] if not (targets := target_conf.get(target)): return [] return [targets] if isinstance(targets, str) else targets