mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 08:26:41 +01:00
1615 lines
54 KiB
Python
1615 lines
54 KiB
Python
"""Triggers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import abc
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from collections.abc import Callable, Coroutine, Iterable, Mapping
|
|
from dataclasses import dataclass, field
|
|
import functools
|
|
import inspect
|
|
import logging
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Final,
|
|
Literal,
|
|
Protocol,
|
|
TypedDict,
|
|
cast,
|
|
override,
|
|
)
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.const import (
|
|
ATTR_ENTITY_ID,
|
|
ATTR_UNIT_OF_MEASUREMENT,
|
|
CONF_ALIAS,
|
|
CONF_DEVICE_ID,
|
|
CONF_ENABLED,
|
|
CONF_ENTITY_ID,
|
|
CONF_EVENT_DATA,
|
|
CONF_ID,
|
|
CONF_OPTIONS,
|
|
CONF_PLATFORM,
|
|
CONF_SELECTOR,
|
|
CONF_TARGET,
|
|
CONF_VARIABLES,
|
|
CONF_ZONE,
|
|
STATE_UNAVAILABLE,
|
|
STATE_UNKNOWN,
|
|
)
|
|
from homeassistant.core import (
|
|
CALLBACK_TYPE,
|
|
Context,
|
|
HassJob,
|
|
HassJobType,
|
|
HomeAssistant,
|
|
State,
|
|
callback,
|
|
get_hassjob_callable_job_type,
|
|
is_callback,
|
|
valid_entity_id,
|
|
)
|
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
|
from homeassistant.loader import (
|
|
Integration,
|
|
IntegrationNotFound,
|
|
async_get_integration,
|
|
async_get_integrations,
|
|
)
|
|
from homeassistant.util.async_ import create_eager_task
|
|
from homeassistant.util.hass_dict import HassKey
|
|
from homeassistant.util.unit_conversion import BaseUnitConverter
|
|
from homeassistant.util.yaml import load_yaml_dict
|
|
|
|
from . import config_validation as cv, selector
|
|
from .automation import (
|
|
DomainSpec,
|
|
ThresholdConfig,
|
|
filter_by_domain_specs,
|
|
get_absolute_description_key,
|
|
get_relative_description_key,
|
|
move_options_fields_to_top_level,
|
|
)
|
|
from .integration_platform import async_process_integration_platforms
|
|
from .selector import (
|
|
NumericThresholdMode,
|
|
NumericThresholdSelector,
|
|
NumericThresholdSelectorConfig,
|
|
NumericThresholdType,
|
|
TargetSelector,
|
|
)
|
|
from .target import (
|
|
TargetStateChangedData,
|
|
async_track_target_selector_state_change_event,
|
|
)
|
|
from .template import Template
|
|
from .typing import UNDEFINED, ConfigType, TemplateVarsType, UndefinedType
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
_PLATFORM_ALIASES = {
|
|
"device": "device_automation",
|
|
"event": "homeassistant",
|
|
"numeric_state": "homeassistant",
|
|
"state": "homeassistant",
|
|
"time_pattern": "homeassistant",
|
|
"time": "homeassistant",
|
|
}
|
|
|
|
DATA_PLUGGABLE_ACTIONS: HassKey[defaultdict[tuple, PluggableActionsEntry]] = HassKey(
|
|
"pluggable_actions"
|
|
)
|
|
|
|
TRIGGER_DESCRIPTION_CACHE: HassKey[dict[str, dict[str, Any] | None]] = HassKey(
|
|
"trigger_description_cache"
|
|
)
|
|
TRIGGER_DISABLED_TRIGGERS: HassKey[set[str]] = HassKey("trigger_disabled_triggers")
|
|
TRIGGER_PLATFORM_SUBSCRIPTIONS: HassKey[
|
|
list[Callable[[set[str]], Coroutine[Any, Any, None]]]
|
|
] = HassKey("trigger_platform_subscriptions")
|
|
TRIGGERS: HassKey[dict[str, str]] = HassKey("triggers")
|
|
|
|
|
|
# Basic schemas to sanity check the trigger descriptions,
|
|
# full validation is done by hassfest.triggers
|
|
_FIELD_DESCRIPTION_SCHEMA = vol.Schema(
|
|
{
|
|
vol.Optional(CONF_SELECTOR): selector.validate_selector,
|
|
},
|
|
extra=vol.ALLOW_EXTRA,
|
|
)
|
|
|
|
_TRIGGER_DESCRIPTION_SCHEMA = vol.Schema(
|
|
{
|
|
vol.Optional("target"): TargetSelector.CONFIG_SCHEMA,
|
|
vol.Optional("fields"): vol.Schema({str: _FIELD_DESCRIPTION_SCHEMA}),
|
|
},
|
|
extra=vol.ALLOW_EXTRA,
|
|
)
|
|
|
|
|
|
def starts_with_dot(key: str) -> str:
    """Validate that the key starts with a dot and return it unchanged.

    Raises vol.Invalid for keys without a leading dot; used to strip
    dotted (private) keys from trigger description YAML.
    """
    if key.startswith("."):
        return key
    raise vol.Invalid("Key does not start with .")
|
|
|
|
|
|
_TRIGGERS_DESCRIPTION_SCHEMA = vol.Schema(
|
|
{
|
|
vol.Remove(vol.All(str, starts_with_dot)): object,
|
|
cv.underscore_slug: vol.Any(None, _TRIGGER_DESCRIPTION_SCHEMA),
|
|
}
|
|
)
|
|
|
|
|
|
async def async_setup(hass: HomeAssistant) -> None:
    """Set up the trigger helper.

    Initializes the shared trigger registries in hass.data, subscribes to
    the new_triggers_conditions labs preview-feature flag, and processes all
    trigger integration platforms.
    """
    # Imported here rather than at module level (PLC0415 suppressed) —
    # presumably to avoid a circular import; confirm before moving to top.
    from homeassistant.components import automation, labs  # noqa: PLC0415

    # Shared registries: description cache, disabled-trigger set,
    # platform-event subscriptions, and the trigger-key -> domain map.
    hass.data[TRIGGER_DESCRIPTION_CACHE] = {}
    hass.data[TRIGGER_DISABLED_TRIGGERS] = set()
    hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS] = []
    hass.data[TRIGGERS] = {}

    async def new_triggers_conditions_listener(
        _event_data: labs.EventLabsUpdatedData,
    ) -> None:
        """Handle new_triggers_conditions flag change."""
        # Invalidate the cache so descriptions and the disabled set are
        # recomputed under the new flag value.
        hass.data[TRIGGER_DESCRIPTION_CACHE] = {}
        hass.data[TRIGGER_DISABLED_TRIGGERS] = set()

    labs.async_subscribe_preview_feature(
        hass,
        automation.DOMAIN,
        automation.NEW_TRIGGERS_CONDITIONS_FEATURE_FLAG,
        new_triggers_conditions_listener,
    )

    # Register every integration's "trigger" platform; waits for platforms
    # so triggers are available once setup returns.
    await async_process_integration_platforms(
        hass, "trigger", _register_trigger_platform, wait_for_platforms=True
    )
|
|
|
|
|
|
@callback
def async_subscribe_platform_events(
    hass: HomeAssistant,
    on_event: Callable[[set[str]], Coroutine[Any, Any, None]],
) -> Callable[[], None]:
    """Subscribe to trigger platform events.

    Returns a callable that removes the subscription again.
    """
    subscriptions = hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]
    subscriptions.append(on_event)

    def _unsubscribe() -> None:
        subscriptions.remove(on_event)

    return _unsubscribe
|
|
|
|
|
|
async def _register_trigger_platform(
    hass: HomeAssistant, integration_domain: str, platform: TriggerProtocol
) -> None:
    """Register a trigger platform and notify listeners.

    If the trigger platform does not provide any triggers, or it is disabled,
    listeners will not be notified.
    """
    from homeassistant.components import automation  # noqa: PLC0415

    new_triggers: set[str] = set()
    triggers = hass.data[TRIGGERS]

    if hasattr(platform, "async_get_triggers"):
        # Platform implementing async_get_triggers may provide several
        # triggers, each keyed relative to the integration domain.
        all_triggers = await platform.async_get_triggers(hass)
        for trigger_key in all_triggers:
            trigger_key = get_absolute_description_key(integration_domain, trigger_key)
            # First registration wins; re-registrations are ignored.
            if trigger_key not in triggers:
                triggers[trigger_key] = integration_domain
                new_triggers.add(trigger_key)
        if not new_triggers:
            if not all_triggers:
                _LOGGER.debug(
                    "Integration %s returned no triggers in async_get_triggers",
                    integration_domain,
                )
            return
    elif hasattr(platform, "async_validate_trigger_config") or hasattr(
        platform, "TRIGGER_SCHEMA"
    ):
        # Platform using the single-trigger API surface: registered under
        # the integration domain itself.
        if integration_domain in triggers:
            return
        triggers[integration_domain] = integration_domain
        new_triggers.add(integration_domain)
    else:
        _LOGGER.debug(
            "Integration %s does not provide trigger support, skipping",
            integration_domain,
        )
        return

    # Experimental triggers can be switched off; in that case listeners are
    # not notified even though the triggers were recorded above.
    if automation.is_disabled_experimental_trigger(hass, integration_domain):
        _LOGGER.debug("Triggers for integration %s are disabled", integration_domain)
        return

    # We don't use gather here because gather adds additional overhead
    # when wrapping each coroutine in a task, and we expect our listeners
    # to call trigger.async_get_all_descriptions which will only yield
    # the first time it's called, after that it returns cached data.
    for listener in hass.data[TRIGGER_PLATFORM_SUBSCRIPTIONS]:
        try:
            await listener(new_triggers)
        except Exception:
            # One failing listener must not prevent notifying the rest.
            _LOGGER.exception("Error while notifying trigger platform listener")
|
|
|
|
|
|
_TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
|
{
|
|
vol.Optional(CONF_OPTIONS): object,
|
|
vol.Optional(CONF_TARGET): cv.TARGET_FIELDS,
|
|
}
|
|
)
|
|
|
|
|
|
class Trigger(abc.ABC):
    """Trigger class."""

    # Home Assistant instance, set in __init__.
    _hass: HomeAssistant

    @classmethod
    async def async_validate_complete_config(
        cls, hass: HomeAssistant, complete_config: ConfigType
    ) -> ConfigType:
        """Validate complete config.

        The complete config includes fields that are generic to all triggers,
        such as the alias or the ID.
        This method should be overridden by triggers that need to migrate
        from the old-style config.
        """
        complete_config = _TRIGGER_SCHEMA(complete_config)

        # Split out the trigger-specific fields (options/target), run the
        # subclass validation on just those, then merge the validated values
        # back into the generic config.
        specific_config: ConfigType = {}
        for key in (CONF_OPTIONS, CONF_TARGET):
            if key in complete_config:
                specific_config[key] = complete_config.pop(key)
        specific_config = await cls.async_validate_config(hass, specific_config)

        for key in (CONF_OPTIONS, CONF_TARGET):
            if key in specific_config:
                complete_config[key] = specific_config[key]

        return complete_config

    @classmethod
    @abc.abstractmethod
    async def async_validate_config(
        cls, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config."""

    def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
        """Initialize trigger."""
        self._hass = hass

    async def async_attach_action(
        self,
        action: TriggerAction,
        action_payload_builder: TriggerActionPayloadBuilder,
    ) -> CALLBACK_TYPE:
        """Attach the trigger to an action.

        Wraps the action in a runner that builds the payload and schedules
        the action as a task, then delegates to async_attach_runner.
        Returns a callback that detaches the trigger.
        """

        @callback
        def run_action(
            extra_trigger_payload: dict[str, Any],
            description: str,
            context: Context | None = None,
        ) -> asyncio.Task[Any]:
            """Run action with trigger variables."""
            payload = action_payload_builder(extra_trigger_payload, description)
            return self._hass.async_create_task(action(payload, context))

        return await self.async_attach_runner(run_action)

    @abc.abstractmethod
    async def async_attach_runner(
        self, run_action: TriggerActionRunner
    ) -> CALLBACK_TYPE:
        """Attach the trigger to an action runner."""
|
|
|
|
|
|
ATTR_BEHAVIOR: Final = "behavior"
|
|
BEHAVIOR_FIRST: Final = "first"
|
|
BEHAVIOR_LAST: Final = "last"
|
|
BEHAVIOR_ANY: Final = "any"
|
|
|
|
ENTITY_STATE_TRIGGER_SCHEMA = vol.Schema(
|
|
{
|
|
vol.Required(CONF_TARGET): cv.TARGET_FIELDS,
|
|
}
|
|
)
|
|
|
|
ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST = ENTITY_STATE_TRIGGER_SCHEMA.extend(
|
|
{
|
|
vol.Required(CONF_OPTIONS): {
|
|
vol.Required(ATTR_BEHAVIOR, default=BEHAVIOR_ANY): vol.In(
|
|
[BEHAVIOR_FIRST, BEHAVIOR_LAST, BEHAVIOR_ANY]
|
|
),
|
|
},
|
|
}
|
|
)
|
|
|
|
|
|
class EntityTriggerBase(Trigger):
    """Trigger for entity state changes."""

    # Mapping of entity domain -> DomainSpec describing which value to track.
    _domain_specs: Mapping[str, DomainSpec]
    # Config schema; subclasses may override with a more specific one.
    _schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST

    @override
    @classmethod
    async def async_validate_config(
        cls, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config."""
        return cast(ConfigType, cls._schema(config))

    def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
        """Initialize the state trigger."""
        super().__init__(hass, config)
        if TYPE_CHECKING:
            # The schema requires CONF_TARGET, so target is never None here.
            assert config.target is not None
        self._options = config.options or {}
        self._target = config.target

    def entity_filter(self, entities: set[str]) -> set[str]:
        """Filter entities matching any of the domain specs."""
        return filter_by_domain_specs(self._hass, self._domain_specs, entities)

    def _get_tracked_value(self, state: State) -> Any:
        """Get the tracked value from a state based on the DomainSpec."""
        domain_spec = self._domain_specs[state.domain]
        if domain_spec.value_source is None:
            # No attribute configured: track the state itself.
            return state.state
        return state.attributes.get(domain_spec.value_source)

    @abc.abstractmethod
    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state is valid and the state has changed."""

    @abc.abstractmethod
    def is_valid_state(self, state: State) -> bool:
        """Check if the new state matches the expected state(s)."""

    def check_all_match(self, entity_ids: set[str]) -> bool:
        """Check if all entity states match."""
        # Entities without a state object are skipped rather than failing.
        return all(
            self.is_valid_state(state)
            for entity_id in entity_ids
            if (state := self._hass.states.get(entity_id)) is not None
        )

    def check_one_match(self, entity_ids: set[str]) -> bool:
        """Check that only one entity state matches."""
        return (
            sum(
                self.is_valid_state(state)
                for entity_id in entity_ids
                if (state := self._hass.states.get(entity_id)) is not None
            )
            == 1
        )

    @override
    async def async_attach_runner(
        self, run_action: TriggerActionRunner
    ) -> CALLBACK_TYPE:
        """Attach the trigger to an action runner."""

        behavior = self._options.get(ATTR_BEHAVIOR)

        @callback
        def state_change_listener(
            target_state_change_data: TargetStateChangedData,
        ) -> None:
            """Listen for state changes and call action."""
            event = target_state_change_data.state_change_event
            entity_id = event.data["entity_id"]
            from_state = event.data["old_state"]
            to_state = event.data["new_state"]

            # Ignore entity additions/removals (missing old or new state).
            if not from_state or not to_state:
                return

            # The trigger should never fire if the new state is not valid
            if not self.is_valid_state(to_state):
                return

            # The trigger should never fire if the transition is not valid
            if not self.is_valid_transition(from_state, to_state):
                return

            if behavior == BEHAVIOR_LAST:
                # "last": fire only when ALL targeted entities now match.
                if not self.check_all_match(
                    target_state_change_data.targeted_entity_ids
                ):
                    return
            elif behavior == BEHAVIOR_FIRST:
                # "first": fire only when exactly ONE targeted entity
                # matches (i.e. this one is the first to match).
                if not self.check_one_match(
                    target_state_change_data.targeted_entity_ids
                ):
                    return

            run_action(
                {
                    ATTR_ENTITY_ID: entity_id,
                    "from_state": from_state,
                    "to_state": to_state,
                },
                f"state of {entity_id}",
                event.context,
            )

        return async_track_target_selector_state_change_event(
            self._hass, self._target, state_change_listener, self.entity_filter
        )
|
|
|
|
|
|
class EntityTargetStateTriggerBase(EntityTriggerBase):
    """Trigger for entity state changes to a specific state.

    Uses _get_tracked_value to extract the value, so it works for both
    state-based and attribute-based triggers depending on the DomainSpec.
    """

    _to_states: set[str]

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state is valid and the state has changed."""
        if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            return False

        previous_value = self._get_tracked_value(from_state)
        if previous_value == self._get_tracked_value(to_state):
            # No actual change of the tracked value.
            return False
        # Only fire when we were not already in one of the target states.
        return previous_value not in self._to_states

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state matches the expected state."""
        return self._get_tracked_value(state) in self._to_states
|
|
|
|
|
|
class EntityTransitionTriggerBase(EntityTriggerBase):
    """Trigger for entity state changes between specific states."""

    _from_states: set[str | bool]
    _to_states: set[str | bool]

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state matches the expected ones."""
        if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            return False

        previous_value = self._get_tracked_value(from_state)
        if previous_value == self._get_tracked_value(to_state):
            # Tracked value did not actually change.
            return False
        return previous_value in self._from_states

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state matches the expected states."""
        return self._get_tracked_value(state) in self._to_states
|
|
|
|
|
|
class EntityOriginStateTriggerBase(EntityTriggerBase):
    """Trigger for entity state changes from a specific state."""

    _from_state: str

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state matches the expected one and that the state changed."""
        if self._get_tracked_value(from_state) != self._from_state:
            return False
        return bool(self._get_tracked_value(to_state) != self._from_state)

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state is valid."""
        if state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            return False
        return bool(self._get_tracked_value(state) != self._from_state)
|
|
|
|
|
|
NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA = ENTITY_STATE_TRIGGER_SCHEMA.extend(
|
|
{
|
|
vol.Required(CONF_OPTIONS, default={}): vol.All(
|
|
{
|
|
vol.Required("threshold"): NumericThresholdSelector(
|
|
NumericThresholdSelectorConfig(mode=NumericThresholdMode.CHANGED)
|
|
)
|
|
},
|
|
)
|
|
}
|
|
)
|
|
|
|
|
|
class EntityNumericalStateTriggerBase(EntityTriggerBase):
    """Base class for numerical state and state attribute triggers."""

    # Expected unit of measurement; UNDEFINED disables the unit check.
    _valid_unit: str | None | UndefinedType = UNDEFINED
    _threshold_type: NumericThresholdType

    def __init__(self, hass: HomeAssistant, config: TriggerConfig) -> None:
        """Initialize the state trigger.

        Reads the "threshold" options dict: "value" for above/below,
        "value_min"/"value_max" for between/outside, and "type" for the
        threshold mode.
        """
        super().__init__(hass, config)
        threshold_options: dict[str, Any] = self._options["threshold"]
        self.threshold = ThresholdConfig.from_config(threshold_options.get("value"))
        self.lower_threshold = ThresholdConfig.from_config(
            threshold_options.get("value_min")
        )
        self.upper_threshold = ThresholdConfig.from_config(
            threshold_options.get("value_max")
        )
        self._threshold_type = threshold_options["type"]

    def _is_valid_unit(self, unit: str | None) -> bool:
        """Check if the given unit is valid for this trigger."""
        if isinstance(self._valid_unit, UndefinedType):
            # No unit requirement configured.
            return True
        return unit == self._valid_unit

    def _get_threshold_value(self, threshold: ThresholdConfig | None) -> float | None:
        """Get threshold value from float or entity state.

        Returns None when no threshold is configured, the referenced entity
        is missing, its unit does not match, or its state is not numeric.
        """
        if threshold is None:
            return None
        if threshold.numerical:
            return threshold.number

        if not (state := self._hass.states.get(threshold.entity)):  # type: ignore[arg-type]
            # Entity not found
            return None
        if not self._is_valid_unit(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)):
            # Entity unit does not match the expected unit
            return None
        try:
            return float(state.state)
        # BUG FIX: `except TypeError, ValueError:` is Python 2 syntax and a
        # SyntaxError in Python 3; the exceptions must be a tuple.
        except (TypeError, ValueError):
            # Entity state is not a valid number
            return None

    def _get_tracked_value(self, state: State) -> float | None:
        """Get the tracked numerical value from a state."""
        domain_spec = self._domain_specs[state.domain]
        raw_value: Any
        if domain_spec.value_source is None:
            # Tracking the state itself: enforce the unit check.
            if not self._is_valid_unit(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)):
                return None
            raw_value = state.state
        else:
            raw_value = state.attributes.get(domain_spec.value_source)

        try:
            return float(raw_value)
        # BUG FIX: tuple form required in Python 3 (was `except TypeError, ValueError:`).
        except (TypeError, ValueError):
            # Entity state is not a valid number
            return None

    def is_valid_state(self, state: State) -> bool:
        """Check if the new state or state attribute matches the expected one."""
        # Handle missing or None value case first to avoid expensive exceptions
        if (current_value := self._get_tracked_value(state)) is None:
            return False

        if self._threshold_type == NumericThresholdType.ANY:
            # If the threshold type is "any" we always trigger on valid state
            # changes
            return True

        if self._threshold_type == NumericThresholdType.ABOVE:
            if (limit := self._get_threshold_value(self.threshold)) is None:
                # Entity not found or invalid number, don't trigger
                return False
            return current_value > limit
        if self._threshold_type == NumericThresholdType.BELOW:
            if (limit := self._get_threshold_value(self.threshold)) is None:
                # Entity not found or invalid number, don't trigger
                return False
            return current_value < limit

        # Mode is BETWEEN or OUTSIDE
        lower_limit = self._get_threshold_value(self.lower_threshold)
        upper_limit = self._get_threshold_value(self.upper_threshold)
        if lower_limit is None or upper_limit is None:
            # Entity not found or invalid number, don't trigger
            return False
        between = lower_limit < current_value < upper_limit
        if self._threshold_type == NumericThresholdType.BETWEEN:
            return between
        return not between
|
|
|
|
|
|
class EntityNumericalStateTriggerWithUnitBase(EntityNumericalStateTriggerBase):
    """Base class for numerical state and state attribute triggers.

    Converts tracked and threshold values to a common base unit using the
    configured unit converter before comparing.
    """

    _base_unit: str | None  # Base unit for the tracked value
    _unit_converter: type[BaseUnitConverter]

    def _get_entity_unit(self, state: State) -> str | None:
        """Get the unit of an entity from its state."""
        return state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)

    def _get_threshold_value(self, threshold: ThresholdConfig | None) -> float | None:
        """Get threshold value from float or entity state, converted to the base unit."""
        if threshold is None:
            return None
        if threshold.numerical:
            # Numeric thresholds carry their own unit; convert to base unit.
            return self._unit_converter.convert(
                threshold.number,  # type: ignore[arg-type]
                threshold.unit,  # type: ignore[arg-type]
                self._base_unit,
            )

        if not (state := self._hass.states.get(threshold.entity)):  # type: ignore[arg-type]
            # Entity not found
            return None
        try:
            value = float(state.state)
        # BUG FIX: `except TypeError, ValueError:` is Python 2 syntax and a
        # SyntaxError in Python 3; the exceptions must be a tuple.
        except (TypeError, ValueError):
            # Entity state is not a valid number
            return None

        try:
            return self._unit_converter.convert(
                value, state.attributes.get(ATTR_UNIT_OF_MEASUREMENT), self._base_unit
            )
        except HomeAssistantError:
            # Unit conversion failed (i.e. incompatible units), treat as invalid number
            return None

    def _get_tracked_value(self, state: State) -> float | None:
        """Get the tracked numerical value from a state, converted to the base unit."""
        domain_spec = self._domain_specs[state.domain]
        raw_value: Any
        if domain_spec.value_source is None:
            raw_value = state.state
        else:
            raw_value = state.attributes.get(domain_spec.value_source)

        try:
            value = float(raw_value)
        # BUG FIX: tuple form required in Python 3 (was `except TypeError, ValueError:`).
        except (TypeError, ValueError):
            # Entity state is not a valid number
            return None

        try:
            return self._unit_converter.convert(
                value, self._get_entity_unit(state), self._base_unit
            )
        except HomeAssistantError:
            # Unit conversion failed (i.e. incompatible units), treat as invalid number
            return None
|
|
|
|
|
|
class EntityNumericalStateChangedTriggerBase(EntityNumericalStateTriggerBase):
    """Trigger for numerical state and state attribute changes."""

    _schema = NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state is valid and the state has changed."""
        if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            return False

        old_value = self._get_tracked_value(from_state)
        new_value = self._get_tracked_value(to_state)
        return old_value != new_value
|
|
|
|
|
|
def make_numerical_state_changed_with_unit_schema(
    unit_converter: type[BaseUnitConverter],
) -> vol.Schema:
    """Factory for numerical state trigger schema with unit option.

    Extends the base entity-state schema with a "threshold" option in
    CHANGED mode, restricted to the converter's valid units.
    """
    return ENTITY_STATE_TRIGGER_SCHEMA.extend(
        {
            vol.Required(CONF_OPTIONS, default={}): vol.All(
                {
                    vol.Required("threshold"): NumericThresholdSelector(
                        NumericThresholdSelectorConfig(
                            mode=NumericThresholdMode.CHANGED,
                            unit_of_measurement=list(unit_converter.VALID_UNITS),
                        )
                    )
                },
            )
        }
    )
|
|
|
|
|
|
class EntityNumericalStateChangedTriggerWithUnitBase(
    EntityNumericalStateChangedTriggerBase,
    EntityNumericalStateTriggerWithUnitBase,
):
    """Trigger for numerical state and state attribute changes."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create a schema.

        Each subclass gets its own schema built from the subclass's
        _unit_converter, which must be set before subclassing.
        """
        super().__init_subclass__(**kwargs)
        cls._schema = make_numerical_state_changed_with_unit_schema(cls._unit_converter)
|
|
|
|
|
|
NUMERICAL_ATTRIBUTE_CROSSED_THRESHOLD_SCHEMA = (
|
|
ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST.extend(
|
|
{
|
|
vol.Required(CONF_OPTIONS): {
|
|
vol.Required("threshold"): NumericThresholdSelector(
|
|
NumericThresholdSelectorConfig(mode=NumericThresholdMode.CROSSED)
|
|
),
|
|
},
|
|
}
|
|
)
|
|
)
|
|
|
|
|
|
class EntityNumericalStateCrossedThresholdTriggerBase(EntityNumericalStateTriggerBase):
    """Trigger for numerical state and state attribute changes.

    This trigger only fires when the observed attribute changes from not within to within
    the defined threshold.
    """

    _schema = NUMERICAL_ATTRIBUTE_CROSSED_THRESHOLD_SCHEMA

    def is_valid_transition(self, from_state: State, to_state: State) -> bool:
        """Check if the origin state is valid and the state has changed."""
        if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
            return False

        # A "crossing" means the previous state did NOT satisfy the threshold.
        return not self.is_valid_state(from_state)
|
|
|
|
|
|
def _make_numerical_state_crossed_threshold_with_unit_schema(
    unit_converter: type[BaseUnitConverter],
) -> vol.Schema:
    """Factory for a numerical crossed-threshold trigger schema with unit option.

    Extends the base first/last entity-state schema with a "threshold"
    option in CROSSED mode, restricted to the converter's valid units.
    (The previous docstring was copy-pasted from a trigger class; this
    function builds a schema, it does not fire on anything.)
    """
    return ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST.extend(
        {
            vol.Required(CONF_OPTIONS, default={}): {
                vol.Required("threshold"): NumericThresholdSelector(
                    NumericThresholdSelectorConfig(
                        mode=NumericThresholdMode.CROSSED,
                        unit_of_measurement=list(unit_converter.VALID_UNITS),
                    )
                ),
            },
        }
    )
|
|
|
|
|
|
class EntityNumericalStateCrossedThresholdTriggerWithUnitBase(
    EntityNumericalStateCrossedThresholdTriggerBase,
    EntityNumericalStateTriggerWithUnitBase,
):
    """Trigger for numerical state and state attribute changes."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create a schema.

        Each subclass gets its own schema built from the subclass's
        _unit_converter, which must be set before subclassing.
        """
        super().__init_subclass__(**kwargs)
        cls._schema = _make_numerical_state_crossed_threshold_with_unit_schema(
            cls._unit_converter
        )
|
|
|
|
|
|
def _normalize_domain_specs(
|
|
domain_specs: Mapping[str, DomainSpec] | str,
|
|
) -> Mapping[str, DomainSpec]:
|
|
"""Normalize domain_specs argument to a Mapping."""
|
|
if isinstance(domain_specs, str):
|
|
return {domain_specs: DomainSpec()}
|
|
return domain_specs
|
|
|
|
|
|
def make_entity_target_state_trigger(
    domain_specs: Mapping[str, DomainSpec] | str,
    to_states: str | set[str],
) -> type[EntityTargetStateTriggerBase]:
    """Create a trigger for entity state changes to specific state(s).

    domain_specs can be a string (domain name) for simple state-based triggers,
    or a Mapping[str, DomainSpec] for attribute-based or multi-domain triggers.
    """
    normalized_specs = _normalize_domain_specs(domain_specs)
    # Accept a single state string as shorthand for a one-element set.
    target_states = {to_states} if isinstance(to_states, str) else to_states

    class CustomTrigger(EntityTargetStateTriggerBase):
        """Trigger for entity state changes."""

        _domain_specs = normalized_specs
        _to_states = target_states

    return CustomTrigger
|
|
|
|
|
|
def make_entity_transition_trigger(
    domain_specs: Mapping[str, DomainSpec] | str,
    *,
    from_states: set[str | bool],
    to_states: set[str | bool],
) -> type[EntityTransitionTriggerBase]:
    """Create a trigger for entity state changes between specific states.

    domain_specs can be a string (domain name) for simple state-based triggers,
    or a Mapping[str, DomainSpec] for attribute-based or multi-domain triggers.
    """
    normalized_specs = _normalize_domain_specs(domain_specs)

    class CustomTrigger(EntityTransitionTriggerBase):
        """Trigger for conditional entity state changes."""

        _domain_specs = normalized_specs
        _from_states = from_states
        _to_states = to_states

    return CustomTrigger
|
|
|
|
|
|
def make_entity_origin_state_trigger(
    domain_specs: Mapping[str, DomainSpec] | str,
    *,
    from_state: str,
) -> type[EntityOriginStateTriggerBase]:
    """Create a trigger for entity state changes from a specific state.

    domain_specs can be a string (domain name) for simple state-based triggers,
    or a Mapping[str, DomainSpec] for attribute-based or multi-domain triggers.
    """
    normalized_specs = _normalize_domain_specs(domain_specs)

    class CustomTrigger(EntityOriginStateTriggerBase):
        """Trigger for entity "from state" changes."""

        _domain_specs = normalized_specs
        _from_state = from_state

    return CustomTrigger
|
|
|
|
|
|
def make_entity_numerical_state_changed_trigger(
    domain_specs: Mapping[str, DomainSpec],
    valid_unit: str | None | UndefinedType = UNDEFINED,
) -> type[EntityNumericalStateChangedTriggerBase]:
    """Create a trigger for numerical state value change."""

    class CustomTrigger(EntityNumericalStateChangedTriggerBase):
        """Trigger for numerical state value changes."""

        _domain_specs = domain_specs
        _valid_unit = valid_unit

    return CustomTrigger
|
|
|
|
|
|
def make_entity_numerical_state_crossed_threshold_trigger(
    domain_specs: Mapping[str, DomainSpec],
    valid_unit: str | None | UndefinedType = UNDEFINED,
) -> type[EntityNumericalStateCrossedThresholdTriggerBase]:
    """Create a trigger for numerical state value crossing a threshold."""

    class CustomTrigger(EntityNumericalStateCrossedThresholdTriggerBase):
        """Trigger for numerical state value crossing a threshold."""

        _domain_specs = domain_specs
        _valid_unit = valid_unit

    return CustomTrigger
|
|
|
|
|
|
def make_entity_numerical_state_changed_with_unit_trigger(
    domain_specs: Mapping[str, DomainSpec],
    base_unit: str,
    unit_converter: type[BaseUnitConverter],
) -> type[EntityNumericalStateChangedTriggerWithUnitBase]:
    """Create a trigger for numerical state value change."""

    class CustomTrigger(EntityNumericalStateChangedTriggerWithUnitBase):
        """Trigger for numerical state value changes."""

        _domain_specs = domain_specs
        _base_unit = base_unit
        _unit_converter = unit_converter

    return CustomTrigger
|
|
|
|
|
|
def make_entity_numerical_state_crossed_threshold_with_unit_trigger(
    domain_specs: Mapping[str, DomainSpec],
    base_unit: str,
    unit_converter: type[BaseUnitConverter],
) -> type[EntityNumericalStateCrossedThresholdTriggerWithUnitBase]:
    """Create a trigger for numerical state value crossing a threshold."""

    class CustomTrigger(EntityNumericalStateCrossedThresholdTriggerWithUnitBase):
        """Trigger for numerical state value crossing a threshold."""

        _domain_specs = domain_specs
        _base_unit = base_unit
        _unit_converter = unit_converter

    return CustomTrigger
|
|
|
|
|
|
class TriggerProtocol(Protocol):
    """Define the format of trigger modules.

    New implementations should only implement async_get_triggers.
    """

    async def async_get_triggers(self, hass: HomeAssistant) -> dict[str, type[Trigger]]:
        """Return the triggers provided by this integration."""

    # Older single-trigger platform surface, still supported by
    # _register_trigger_platform (schema + validate + attach):
    TRIGGER_SCHEMA: vol.Schema

    async def async_validate_trigger_config(
        self, hass: HomeAssistant, config: ConfigType
    ) -> ConfigType:
        """Validate config."""

    async def async_attach_trigger(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        action: TriggerActionType,
        trigger_info: TriggerInfo,
    ) -> CALLBACK_TYPE:
        """Attach a trigger."""
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
class TriggerConfig:
    """Trigger config."""

    key: str  # The key used to identify the trigger, e.g. "zwave.event"
    # Target selection (CONF_TARGET fields); None when the trigger takes none.
    target: dict[str, Any] | None = None
    # Trigger-specific options (CONF_OPTIONS); None when not configured.
    options: dict[str, Any] | None = None
|
|
|
|
|
|
class TriggerActionRunner(Protocol):
    """Protocol type for the trigger action runner helper callback.

    Matches the run_action helper created by Trigger.async_attach_action.
    """

    @callback
    def __call__(
        self,
        extra_trigger_payload: dict[str, Any],
        description: str,
        context: Context | None = None,
    ) -> asyncio.Task[Any]:
        """Define trigger action runner type.

        Returns:
            A Task that allows awaiting for the action to finish.

        """
|
|
|
|
|
|
class TriggerActionPayloadBuilder(Protocol):
    """Protocol type for the trigger action payload builder.

    Builds the variables dict passed to the action from the extra trigger
    payload and a human readable description.
    """

    def __call__(
        self, extra_trigger_payload: dict[str, Any], description: str
    ) -> dict[str, Any]:
        """Define trigger action payload builder type."""
|
|
|
|
|
|
class TriggerAction(Protocol):
    """Protocol type for trigger action callback.

    Always a coroutine function; see TriggerActionType for the variant that
    also accepts sync callables.
    """

    async def __call__(
        self, run_variables: dict[str, Any], context: Context | None = None
    ) -> Any:
        """Define action callback type."""
|
|
|
|
|
|
class TriggerActionType(Protocol):
    """Protocol type for trigger action callback.

    Contrary to TriggerAction, this type supports both sync and async callables.
    """

    def __call__(
        self,
        run_variables: dict[str, Any],
        context: Context | None = None,
    ) -> Coroutine[Any, Any, Any] | Any:
        """Define action callback type."""
|
|
|
|
|
|
class TriggerData(TypedDict):
    """Trigger data."""

    # User-configured trigger id; defaults to the index as a string.
    id: str
    # Position of the trigger in the trigger list, as a string.
    idx: str
    # Optional user-provided alias for the trigger.
    alias: str | None
|
|
|
|
|
|
class TriggerInfo(TypedDict):
    """Information about trigger."""

    # Domain supplied by the caller initializing the triggers.
    domain: str
    # Name supplied by the caller initializing the triggers.
    name: str
    # Flag forwarded to trigger platforms; presumably indicates attach during
    # Home Assistant startup — confirm against callers.
    home_assistant_start: bool
    # Template variables available when rendering trigger templates.
    variables: TemplateVarsType
    # Per-trigger id/idx/alias data, also exposed in the action payload.
    trigger_data: TriggerData
|
|
|
|
|
|
@dataclass(slots=True)
class PluggableActionsEntry:
    """Holder to keep track of all plugs and actions for a given trigger."""

    # PluggableAction instances registered for this trigger key.
    plugs: set[PluggableAction] = field(default_factory=set)
    # Attached actions keyed by their removal callback; each value pairs the
    # HassJob to run with the variables to call it with.
    actions: dict[
        object,
        tuple[
            HassJob[[dict[str, Any], Context | None], Coroutine[Any, Any, None] | Any],
            dict[str, Any],
        ],
    ] = field(default_factory=dict)
|
|
|
|
|
|
class PluggableAction:
    """A pluggable action handler.

    A plug registers itself for a trigger key; actions attached to the same
    key (now or later) are run via async_run. The shared registry entry is
    cleaned up once it has neither actions nor plugs.
    """

    # Registry entry this plug is registered with; None while unregistered.
    _entry: PluggableActionsEntry | None = None

    def __init__(self, update: CALLBACK_TYPE | None = None) -> None:
        """Initialize a pluggable action.

        :param update: callback triggered whenever triggers are attached or removed.
        """
        self._update = update

    def __bool__(self) -> bool:
        """Return if we have something attached."""
        return bool(self._entry and self._entry.actions)

    @callback
    def async_run_update(self) -> None:
        """Run update function if one exists."""
        if self._update:
            self._update()

    @staticmethod
    @callback
    def async_get_registry(hass: HomeAssistant) -> dict[tuple, PluggableActionsEntry]:
        """Return the pluggable actions registry."""
        # NOTE(review): an existing-but-empty registry fails the truthiness
        # check and is replaced with a fresh defaultdict; harmless, since
        # entries are deleted once they hold no actions and no plugs.
        if data := hass.data.get(DATA_PLUGGABLE_ACTIONS):
            return data
        data = hass.data[DATA_PLUGGABLE_ACTIONS] = defaultdict(PluggableActionsEntry)
        return data

    @staticmethod
    @callback
    def async_attach_trigger(
        hass: HomeAssistant,
        trigger: dict[str, str],
        action: TriggerActionType,
        variables: dict[str, Any],
    ) -> CALLBACK_TYPE:
        """Attach an action to a trigger entry.

        Existing or future plugs registered will be attached.

        :param trigger: key/value pairs identifying the trigger entry.
        :param action: callable run when the trigger fires.
        :param variables: variables passed to the action when run.
        :return: callback that detaches the action again.
        """
        reg = PluggableAction.async_get_registry(hass)
        # Sorted items make the registry key independent of dict ordering.
        key = tuple(sorted(trigger.items()))
        entry = reg[key]

        def _update() -> None:
            # Notify all plugs that the set of attached actions changed.
            for plug in entry.plugs:
                plug.async_run_update()

        @callback
        def _remove() -> None:
            """Remove this action attachment, and disconnect all plugs."""
            # _remove itself is the dict key for this attachment.
            del entry.actions[_remove]
            _update()
            if not entry.actions and not entry.plugs:
                del reg[key]

        job = HassJob(action, f"trigger {trigger} {variables}")
        entry.actions[_remove] = (job, variables)
        _update()

        return _remove

    @callback
    def async_register(
        self, hass: HomeAssistant, trigger: dict[str, str]
    ) -> CALLBACK_TYPE:
        """Register plug in the global plugs dictionary.

        :return: callback that unregisters the plug again.
        """

        reg = PluggableAction.async_get_registry(hass)
        key = tuple(sorted(trigger.items()))
        self._entry = reg[key]
        self._entry.plugs.add(self)

        @callback
        def _remove() -> None:
            """Remove plug from registration.

            Clean up entry if there are no actions or plugs registered.
            """
            assert self._entry
            self._entry.plugs.remove(self)
            if not self._entry.actions and not self._entry.plugs:
                del reg[key]
            self._entry = None

        return _remove

    async def async_run(
        self, hass: HomeAssistant, context: Context | None = None
    ) -> None:
        """Run all actions.

        Must only be called while the plug is registered (``_entry`` set).
        """
        assert self._entry
        for job, variables in self._entry.actions.values():
            task = hass.async_run_hass_job(job, variables, context)
            # Callback jobs may return None; only await real tasks.
            if task:
                await task
|
|
|
|
|
|
async def _async_get_trigger_platform(
    hass: HomeAssistant, trigger_key: str
) -> tuple[str, TriggerProtocol]:
    """Resolve a trigger key to its platform domain and trigger module."""
    from homeassistant.components import automation  # noqa: PLC0415

    platform = trigger_key.split(".")[0]
    # Only apply aliases for old-style triggers (no sub_type).
    # New-style triggers (e.g. "event.received") use the integration domain directly.
    if "." not in trigger_key:
        platform = _PLATFORM_ALIASES.get(platform, platform)

    # Experimental triggers are rejected unless the Labs feature is enabled.
    if automation.is_disabled_experimental_trigger(hass, platform):
        raise vol.Invalid(
            f"Trigger '{trigger_key}' requires the experimental 'New triggers and "
            "conditions' feature to be enabled in Home Assistant Labs settings "
            f"(feature flag: '{automation.NEW_TRIGGERS_CONDITIONS_FEATURE_FLAG}')"
        )

    try:
        integration = await async_get_integration(hass, platform)
    except IntegrationNotFound:
        raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None

    try:
        platform_module = await integration.async_get_platform("trigger")
    except ImportError:
        raise vol.Invalid(
            f"Integration '{platform}' does not provide trigger support"
        ) from None

    # Ensure triggers are registered so descriptions can be loaded
    await _register_trigger_platform(hass, platform, platform_module)

    return platform, platform_module
|
|
|
|
|
|
async def async_validate_trigger_config(
    hass: HomeAssistant, trigger_config: list[ConfigType]
) -> list[ConfigType]:
    """Validate triggers."""
    validated: list[ConfigType] = []
    for conf in trigger_config:
        trigger_key: str = conf[CONF_PLATFORM]
        platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key)
        if hasattr(platform, "async_get_triggers"):
            # New-style platform: delegate validation to the trigger class.
            relative_key = get_relative_description_key(platform_domain, trigger_key)
            trigger = (await platform.async_get_triggers(hass)).get(relative_key)
            if trigger is None:
                raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified")
            conf = await trigger.async_validate_complete_config(hass, conf)
        else:
            # Legacy platform: normalize options first, then validate with the
            # platform's validator or its schema.
            conf = move_options_fields_to_top_level(conf, cv.TRIGGER_BASE_SCHEMA)
            if hasattr(platform, "async_validate_trigger_config"):
                conf = await platform.async_validate_trigger_config(hass, conf)
            else:
                conf = platform.TRIGGER_SCHEMA(conf)
        validated.append(conf)
    return validated
|
|
|
|
|
|
def _trigger_action_wrapper(
    hass: HomeAssistant, action: Callable, conf: ConfigType
) -> Callable:
    """Wrap trigger action with extra vars if configured.

    If action is a coroutine function, a coroutine function will be returned.
    If action is a callback, a callback will be returned.
    """
    if CONF_VARIABLES not in conf:
        return action

    # Unwrap partials so we inspect the real target function when deciding
    # whether it is a coroutine function or a callback.
    check_func = action
    while isinstance(check_func, functools.partial):
        check_func = check_func.func

    if inspect.iscoroutinefunction(check_func):
        async_action = cast(Callable[..., Coroutine[Any, Any, Any]], action)

        @functools.wraps(async_action)
        async def async_with_vars(
            run_variables: dict[str, Any], context: Context | None = None
        ) -> Any:
            """Wrap action with extra vars."""
            run_variables.update(
                conf[CONF_VARIABLES].async_render(hass, run_variables)
            )
            return await action(run_variables, context)

        return async_with_vars

    @functools.wraps(action)
    def with_vars(
        run_variables: dict[str, Any], context: Context | None = None
    ) -> Any:
        """Wrap action with extra vars."""
        run_variables.update(conf[CONF_VARIABLES].async_render(hass, run_variables))
        return action(run_variables, context)

    # Preserve the callback nature of the wrapped action.
    if is_callback(check_func):
        return callback(with_vars)

    return with_vars
|
|
|
|
|
|
async def _async_attach_trigger_cls(
    hass: HomeAssistant,
    trigger_cls: type[Trigger],
    trigger_key: str,
    conf: ConfigType,
    action: Callable,
    trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
    """Initialize a new Trigger class and attach it.

    Instantiates trigger_cls with a TriggerConfig built from conf and attaches
    the (always-async-wrapped) action to it.
    """

    def action_payload_builder(
        extra_trigger_payload: dict[str, Any], description: str
    ) -> dict[str, Any]:
        """Build action variables."""
        # The payload combines the shared trigger data (id/idx/alias) with the
        # trigger key, a human readable description, and any extra payload
        # supplied by the trigger implementation.
        payload = {
            "trigger": {
                **trigger_info["trigger_data"],
                CONF_PLATFORM: trigger_key,
                "description": description,
                **extra_trigger_payload,
            }
        }
        # Render configured trigger variables with the payload in scope and
        # merge the rendered result into the payload.
        if CONF_VARIABLES in conf:
            trigger_variables = conf[CONF_VARIABLES]
            payload.update(trigger_variables.async_render(hass, payload))
        return payload

    # Wrap sync action so that it is always async.
    # This simplifies the Trigger action runner interface by always returning a coroutine,
    # removing the need for integrations to check for the return type when awaiting the action.
    match get_hassjob_callable_job_type(action):
        case HassJobType.Executor:
            # Bind the original before rebinding `action`, so the closure does
            # not call itself recursively.
            original_action = action

            async def wrapped_executor_action(
                run_variables: dict[str, Any], context: Context | None = None
            ) -> Any:
                """Wrap sync action to be called in executor."""
                return await hass.async_add_executor_job(
                    original_action, run_variables, context
                )

            action = wrapped_executor_action

        case HassJobType.Callback:
            original_action = action

            async def wrapped_callback_action(
                run_variables: dict[str, Any], context: Context | None = None
            ) -> Any:
                """Wrap callback action to be awaitable."""
                return original_action(run_variables, context)

            action = wrapped_callback_action

    trigger = trigger_cls(
        hass,
        TriggerConfig(
            key=trigger_key,
            target=conf.get(CONF_TARGET),
            options=conf.get(CONF_OPTIONS),
        ),
    )
    return await trigger.async_attach_action(action, action_payload_builder)
|
|
|
|
|
|
async def async_initialize_triggers(
    hass: HomeAssistant,
    trigger_config: list[ConfigType],
    action: Callable,
    domain: str,
    name: str,
    log_cb: Callable,
    home_assistant_start: bool = False,
    variables: TemplateVarsType = None,
) -> CALLBACK_TYPE | None:
    """Initialize triggers.

    Attaches every enabled trigger in trigger_config to the given action and
    returns a single callback that detaches them all, or None when no trigger
    could be attached.
    """
    triggers: list[asyncio.Task[CALLBACK_TYPE]] = []
    for idx, conf in enumerate(trigger_config):
        # Skip triggers that are not enabled
        if CONF_ENABLED in conf:
            enabled = conf[CONF_ENABLED]
            # "enabled" may be a (limited) template; render it with the
            # provided variables before evaluating truthiness.
            if isinstance(enabled, Template):
                try:
                    enabled = enabled.async_render(variables, limited=True)
                except TemplateError as err:
                    log_cb(logging.ERROR, f"Error rendering enabled template: {err}")
                    continue
            if not enabled:
                continue

        trigger_key: str = conf[CONF_PLATFORM]
        platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key)
        # The trigger id defaults to the trigger's list index when not set.
        trigger_id = conf.get(CONF_ID, f"{idx}")
        trigger_idx = f"{idx}"
        trigger_alias = conf.get(CONF_ALIAS)
        trigger_data = TriggerData(id=trigger_id, idx=trigger_idx, alias=trigger_alias)
        info = TriggerInfo(
            domain=domain,
            name=name,
            home_assistant_start=home_assistant_start,
            variables=variables,
            trigger_data=trigger_data,
        )

        # New-style platforms expose trigger classes via async_get_triggers;
        # legacy platforms expose async_attach_trigger.
        if hasattr(platform, "async_get_triggers"):
            trigger_descriptors = await platform.async_get_triggers(hass)
            relative_trigger_key = get_relative_description_key(
                platform_domain, trigger_key
            )
            trigger_cls = trigger_descriptors[relative_trigger_key]
            coro = _async_attach_trigger_cls(
                hass, trigger_cls, trigger_key, conf, action, info
            )
        else:
            action_wrapper = _trigger_action_wrapper(hass, action, conf)
            coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)

        triggers.append(create_eager_task(coro))

    # Attach all triggers concurrently; exceptions are collected per trigger.
    attach_results = await asyncio.gather(*triggers, return_exceptions=True)
    removes: list[Callable[[], None]] = []

    for result in attach_results:
        if isinstance(result, HomeAssistantError):
            log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for")
        elif isinstance(result, Exception):
            log_cb(logging.ERROR, "Error setting up trigger", exc_info=result)
        elif isinstance(result, BaseException):
            # Never swallow BaseExceptions such as CancelledError.
            raise result from None
        elif result is None:
            log_cb(  # type: ignore[unreachable]
                logging.ERROR, "Unknown error while setting up trigger (empty result)"
            )
        else:
            removes.append(result)

    if not removes:
        return None

    log_cb(logging.INFO, "Initialized trigger")

    @callback
    def remove_triggers() -> None:
        """Remove triggers."""
        for remove in removes:
            remove()

    return remove_triggers
|
|
|
|
|
|
def _load_triggers_file(integration: Integration) -> dict[str, Any]:
    """Load triggers file for an integration.

    Returns an empty dict when the file is missing or fails validation.
    """
    triggers_path = str(integration.file_path / "triggers.yaml")
    try:
        raw = load_yaml_dict(triggers_path)
        return cast(dict[str, Any], _TRIGGERS_DESCRIPTION_SCHEMA(raw))
    except FileNotFoundError:
        _LOGGER.warning(
            "Unable to find triggers.yaml for the %s integration", integration.domain
        )
    except (HomeAssistantError, vol.Invalid) as ex:
        _LOGGER.warning(
            "Unable to parse triggers.yaml for the %s integration: %s",
            integration.domain,
            ex,
        )
    return {}
|
|
|
|
|
|
def _load_triggers_files(
    integrations: Iterable[Integration],
) -> dict[str, dict[str, Any]]:
    """Load trigger files for multiple integrations."""
    descriptions: dict[str, dict[str, Any]] = {}
    for integration in integrations:
        domain = integration.domain
        # Keys are stored in their absolute form, e.g. "<domain>.<trigger>".
        descriptions[domain] = {
            get_absolute_description_key(domain, key): value
            for key, value in _load_triggers_file(integration).items()
        }
    return descriptions
|
|
|
|
|
|
async def async_get_all_descriptions(
    hass: HomeAssistant,
) -> dict[str, dict[str, Any] | None]:
    """Return descriptions (i.e. user documentation) for all triggers.

    Descriptions are loaded from each integration's triggers.yaml and cached.
    A trigger with no description is cached as None so it is not re-read.
    """
    from homeassistant.components import automation  # noqa: PLC0415

    descriptions_cache = hass.data[TRIGGER_DESCRIPTION_CACHE]

    triggers = hass.data[TRIGGERS]
    # See if there are new triggers not seen before.
    # Any trigger that we saw before already has an entry in description_cache.
    all_triggers = set(triggers)
    previous_all_triggers = set(descriptions_cache)
    # If the triggers are the same, we can return the cache

    # mypy complains: Invalid index type "HassKey[set[str]]" for "HassDict"
    if previous_all_triggers | hass.data[TRIGGER_DISABLED_TRIGGERS] == all_triggers:  # type: ignore[index]
        return descriptions_cache

    # Files we loaded for missing descriptions
    new_triggers_descriptions: dict[str, dict[str, Any]] = {}
    # We try to avoid making a copy in the event the cache is good,
    # but now we must make a copy in case new triggers get added
    # while we are loading the missing ones so we do not
    # add the new ones to the cache without their descriptions
    triggers = triggers.copy()

    if missing_triggers := all_triggers.difference(descriptions_cache):
        domains_with_missing_triggers = {
            triggers[missing_trigger] for missing_trigger in missing_triggers
        }
        ints_or_excs = await async_get_integrations(hass, domains_with_missing_triggers)
        integrations: list[Integration] = []
        for domain, int_or_exc in ints_or_excs.items():
            if type(int_or_exc) is Integration and int_or_exc.has_triggers:
                integrations.append(int_or_exc)
                continue
            if TYPE_CHECKING:
                assert isinstance(int_or_exc, Exception)
            _LOGGER.debug(
                "Failed to load triggers.yaml for integration: %s",
                domain,
                exc_info=int_or_exc,
            )

        if integrations:
            # triggers.yaml files are read from disk in the executor.
            new_triggers_descriptions = await hass.async_add_executor_job(
                _load_triggers_files, integrations
            )

    # Make a copy of the old cache and add missing descriptions to it
    new_descriptions_cache = descriptions_cache.copy()
    for missing_trigger in missing_triggers:
        domain = triggers[missing_trigger]
        # Disabled experimental triggers are tracked separately so the
        # cache-validity check above remains correct.
        if automation.is_disabled_experimental_trigger(hass, domain):
            hass.data[TRIGGER_DISABLED_TRIGGERS].add(missing_trigger)
            continue

        if (
            yaml_description := new_triggers_descriptions.get(domain, {}).get(
                missing_trigger
            )
        ) is None:
            _LOGGER.debug(
                "No trigger descriptions found for trigger %s, skipping",
                missing_trigger,
            )
            new_descriptions_cache[missing_trigger] = None
            continue

        # Only "fields" and (optionally) "target" are exposed from the YAML.
        description = {"fields": yaml_description.get("fields", {})}
        if (target := yaml_description.get("target")) is not None:
            description["target"] = target

        new_descriptions_cache[missing_trigger] = description
    hass.data[TRIGGER_DESCRIPTION_CACHE] = new_descriptions_cache
    return new_descriptions_cache
|
|
|
|
|
|
@callback
def async_extract_devices(trigger_conf: dict) -> list[str]:
    """Extract devices from a trigger config."""
    platform = trigger_conf[CONF_PLATFORM]

    if platform == "device":
        return [trigger_conf[CONF_DEVICE_ID]]

    if platform == "event":
        event_data = trigger_conf.get(CONF_EVENT_DATA)
        if (
            event_data
            and CONF_DEVICE_ID in event_data
            and isinstance(event_data[CONF_DEVICE_ID], str)
        ):
            return [event_data[CONF_DEVICE_ID]]

    if platform == "tag" and CONF_DEVICE_ID in trigger_conf:
        return trigger_conf[CONF_DEVICE_ID]  # type: ignore[no-any-return]

    # Fall back to devices referenced via the target config, if any.
    return async_extract_targets(trigger_conf, CONF_DEVICE_ID) or []
|
|
|
|
|
|
@callback
def async_extract_entities(trigger_conf: dict) -> list[str]:
    """Extract entities from a trigger config."""
    platform = trigger_conf[CONF_PLATFORM]

    if platform in ("state", "numeric_state"):
        return trigger_conf[CONF_ENTITY_ID]  # type: ignore[no-any-return]

    if platform == "calendar":
        return [trigger_conf[CONF_OPTIONS][CONF_ENTITY_ID]]

    if platform == "zone":
        return trigger_conf[CONF_ENTITY_ID] + [trigger_conf[CONF_ZONE]]  # type: ignore[no-any-return]

    if platform == "geo_location":
        return [trigger_conf[CONF_ZONE]]

    if platform == "sun":
        return ["sun.sun"]

    if platform == "event":
        event_data = trigger_conf.get(CONF_EVENT_DATA)
        if (
            event_data
            and CONF_ENTITY_ID in event_data
            and isinstance(event_data[CONF_ENTITY_ID], str)
            and valid_entity_id(event_data[CONF_ENTITY_ID])
        ):
            return [event_data[CONF_ENTITY_ID]]

    # Fall back to entities referenced via the target config, if any.
    return async_extract_targets(trigger_conf, CONF_ENTITY_ID) or []
|
|
|
|
|
|
@callback
def async_extract_targets(
    config: dict,
    target: Literal["entity_id", "device_id", "area_id", "floor_id", "label_id"],
) -> list[str]:
    """Extract targets from a target config."""
    target_conf = config.get(CONF_TARGET)
    if not target_conf:
        return []

    targets = target_conf.get(target)
    if not targets:
        return []

    # Normalize a single target string into a one-element list.
    if isinstance(targets, str):
        return [targets]
    return targets
|