diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 5ce6fca7c2e..b173fa5b248 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -123,6 +123,7 @@ SERVICE_TRIGGER = "trigger" NEW_TRIGGERS_CONDITIONS_FEATURE_FLAG = "new_triggers_conditions" _EXPERIMENTAL_CONDITION_PLATFORMS = { + "fan", "light", } diff --git a/homeassistant/components/fan/condition.py b/homeassistant/components/fan/condition.py new file mode 100644 index 00000000000..2063e98033e --- /dev/null +++ b/homeassistant/components/fan/condition.py @@ -0,0 +1,17 @@ +"""Provides conditions for fans.""" + +from homeassistant.const import STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant +from homeassistant.helpers.condition import Condition, make_entity_state_condition + +from . import DOMAIN + +CONDITIONS: dict[str, type[Condition]] = { + "is_off": make_entity_state_condition(DOMAIN, STATE_OFF), + "is_on": make_entity_state_condition(DOMAIN, STATE_ON), +} + + +async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]: + """Return the fan conditions.""" + return CONDITIONS diff --git a/homeassistant/components/fan/conditions.yaml b/homeassistant/components/fan/conditions.yaml new file mode 100644 index 00000000000..2f7e4fca5b9 --- /dev/null +++ b/homeassistant/components/fan/conditions.yaml @@ -0,0 +1,17 @@ +.condition_common: &condition_common + target: + entity: + domain: fan + fields: + behavior: + required: true + default: any + selector: + select: + translation_key: condition_behavior + options: + - all + - any + +is_off: *condition_common +is_on: *condition_common diff --git a/homeassistant/components/fan/icons.json b/homeassistant/components/fan/icons.json index 9f52b55bf7d..91a1924056f 100644 --- a/homeassistant/components/fan/icons.json +++ b/homeassistant/components/fan/icons.json @@ -1,4 +1,12 @@ { + "conditions": { + "is_off": { + "condition": "mdi:fan-off" + }, + "is_on": { + "condition": "mdi:fan" + } + }, "entity_component": { "_": { "default": "mdi:fan", diff --git a/homeassistant/components/fan/strings.json b/homeassistant/components/fan/strings.json index a6e4b91c65e..ba6df0a288f 100644 --- a/homeassistant/components/fan/strings.json +++ b/homeassistant/components/fan/strings.json @@ -1,8 +1,32 @@ { "common": { + "condition_behavior_description": "How the state should match on the targeted fans.", + "condition_behavior_name": "Behavior", "trigger_behavior_description": "The behavior of the targeted fans to trigger on.", "trigger_behavior_name": "Behavior" }, + "conditions": { + "is_off": { + "description": "Tests if one or more fans are off.", + "fields": { + "behavior": { + "description": "[%key:component::fan::common::condition_behavior_description%]", + "name": "[%key:component::fan::common::condition_behavior_name%]" + } + }, + "name": "If a fan is off" + }, + "is_on": { + "description": "Tests if one or more fans are on.", + "fields": { + "behavior": { + "description": "[%key:component::fan::common::condition_behavior_description%]", + "name": "[%key:component::fan::common::condition_behavior_name%]" + } + }, + "name": "If a fan is on" + } + }, "device_automation": { "action_type": { "toggle": "[%key:common::device_automation::action_type::toggle%]", @@ -65,6 +89,12 @@ } }, "selector": { + "condition_behavior": { + "options": { + "all": "All", + "any": "Any" + } + }, "direction": { "options": { "forward": "Forward", diff --git a/homeassistant/components/light/condition.py b/homeassistant/components/light/condition.py index 2b2ac0acca6..59fcd10c831 100644 --- a/homeassistant/components/light/condition.py +++ b/homeassistant/components/light/condition.py @@ -1,126 +1,14 @@ """Provides conditions for lights.""" -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Final, Unpack, override - -import voluptuous as vol - -from homeassistant.const import CONF_OPTIONS, CONF_TARGET, STATE_OFF, STATE_ON -from homeassistant.core import HomeAssistant, split_entity_id -from homeassistant.helpers import config_validation as cv, target -from homeassistant.helpers.condition import ( - Condition, - ConditionChecker, - ConditionCheckParams, - ConditionConfig, -) -from homeassistant.helpers.typing import ConfigType +from homeassistant.const import STATE_OFF, STATE_ON +from homeassistant.core import HomeAssistant +from homeassistant.helpers.condition import Condition, make_entity_state_condition from .const import DOMAIN -ATTR_BEHAVIOR: Final = "behavior" -BEHAVIOR_ANY: Final = "any" -BEHAVIOR_ALL: Final = "all" - - -STATE_CONDITION_VALID_STATES: Final = [STATE_ON, STATE_OFF] -STATE_CONDITION_OPTIONS_SCHEMA: dict[vol.Marker, Any] = { - vol.Required(ATTR_BEHAVIOR, default=BEHAVIOR_ANY): vol.In( - [BEHAVIOR_ANY, BEHAVIOR_ALL] - ), -} -STATE_CONDITION_SCHEMA = vol.Schema( - { - vol.Required(CONF_TARGET): cv.TARGET_FIELDS, - vol.Required(CONF_OPTIONS): STATE_CONDITION_OPTIONS_SCHEMA, - } -) - - -class StateConditionBase(Condition): - """State condition.""" - - @override - @classmethod - async def async_validate_config( - cls, hass: HomeAssistant, config: ConfigType - ) -> ConfigType: - """Validate config.""" - return STATE_CONDITION_SCHEMA(config) # type: ignore[no-any-return] - - def __init__( - self, hass: HomeAssistant, config: ConditionConfig, state: str - ) -> None: - """Initialize condition.""" - super().__init__(hass, config) - if TYPE_CHECKING: - assert config.target - assert config.options - self._target_selection = target.TargetSelection(config.target) - self._behavior = config.options[ATTR_BEHAVIOR] - self._state = state - - @override - async def async_get_checker(self) -> ConditionChecker: - """Get the condition checker.""" - - def check_any_match_state(states: list[str]) -> bool: - """Test if any entity match the state.""" - return any(state == self._state for state in states) - - def check_all_match_state(states: list[str]) -> bool: - """Test if all entities match the state.""" - return all(state == self._state for state in states) - - matcher: Callable[[list[str]], bool] - if self._behavior == BEHAVIOR_ANY: - matcher = check_any_match_state - elif self._behavior == BEHAVIOR_ALL: - matcher = check_all_match_state - - def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool: - """Test state condition.""" - targeted_entities = target.async_extract_referenced_entity_ids( - self._hass, self._target_selection, expand_group=False - ) - referenced_entity_ids = targeted_entities.referenced.union( - targeted_entities.indirectly_referenced - ) - light_entity_ids = { - entity_id - for entity_id in referenced_entity_ids - if split_entity_id(entity_id)[0] == DOMAIN - } - light_entity_states = [ - state.state - for entity_id in light_entity_ids - if (state := self._hass.states.get(entity_id)) - and state.state in STATE_CONDITION_VALID_STATES - ] - return matcher(light_entity_states) - - return test_state - - -class IsOnCondition(StateConditionBase): - """Is on condition.""" - - def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None: - """Initialize condition.""" - super().__init__(hass, config, STATE_ON) - - -class IsOffCondition(StateConditionBase): - """Is off condition.""" - - def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None: - """Initialize condition.""" - super().__init__(hass, config, STATE_OFF) - - CONDITIONS: dict[str, type[Condition]] = { - "is_off": IsOffCondition, - "is_on": IsOnCondition, + "is_off": make_entity_state_condition(DOMAIN, STATE_OFF), + "is_on": make_entity_state_condition(DOMAIN, STATE_ON), } diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index 957ff25434f..6e9f0e55365 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -13,7 +13,17 @@ import inspect import logging import re import sys -from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Final, + Protocol, + TypedDict, + Unpack, + cast, + overload, + override, +) import voluptuous as vol @@ -43,7 +53,7 @@ from homeassistant.const import ( STATE_UNKNOWN, WEEKDAYS, ) -from homeassistant.core import HomeAssistant, State, callback +from homeassistant.core import HomeAssistant, State, callback, split_entity_id from homeassistant.exceptions import ( ConditionError, ConditionErrorContainer, @@ -71,6 +81,7 @@ from .automation import ( ) from .integration_platform import async_process_integration_platforms from .selector import TargetSelector +from .target import TargetSelection, async_extract_referenced_entity_ids from .template import Template, render_complex from .trace import ( TraceElement, @@ -302,6 +313,112 @@ class Condition(abc.ABC): """Get the condition checker.""" +ATTR_BEHAVIOR: Final = "behavior" +BEHAVIOR_ANY: Final = "any" +BEHAVIOR_ALL: Final = "all" + +STATE_CONDITION_OPTIONS_SCHEMA: dict[vol.Marker, Any] = { + vol.Required(ATTR_BEHAVIOR, default=BEHAVIOR_ANY): vol.In( + [BEHAVIOR_ANY, BEHAVIOR_ALL] + ), +} +ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL = vol.Schema( + { + vol.Required(CONF_TARGET): cv.TARGET_FIELDS, + vol.Required(CONF_OPTIONS): STATE_CONDITION_OPTIONS_SCHEMA, + } +) + + +class EntityStateConditionBase(Condition): + """State condition.""" + + _domain: str + _schema: vol.Schema = ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL + _states: set[str] + + @override + @classmethod + async def async_validate_config( + cls, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + return cast(ConfigType, cls._schema(config)) + + def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None: + """Initialize condition.""" + super().__init__(hass, config) + if TYPE_CHECKING: + assert config.target + assert config.options + self._target_selection = TargetSelection(config.target) + self._behavior = config.options[ATTR_BEHAVIOR] + + def entity_filter(self, entities: set[str]) -> set[str]: + """Filter entities of this domain.""" + return { + entity_id + for entity_id in entities + if split_entity_id(entity_id)[0] == self._domain + } + + @override + async def async_get_checker(self) -> ConditionChecker: + """Get the condition checker.""" + + def check_any_match_state(states: list[str]) -> bool: + """Test if any entity match the state.""" + return any(state in self._states for state in states) + + def check_all_match_state(states: list[str]) -> bool: + """Test if all entities match the state.""" + return all(state in self._states for state in states) + + matcher: Callable[[list[str]], bool] + if self._behavior == BEHAVIOR_ANY: + matcher = check_any_match_state + elif self._behavior == BEHAVIOR_ALL: + matcher = check_all_match_state + + def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool: + """Test state condition.""" + targeted_entities = async_extract_referenced_entity_ids( + self._hass, self._target_selection, expand_group=False + ) + referenced_entity_ids = targeted_entities.referenced.union( + targeted_entities.indirectly_referenced + ) + filtered_entity_ids = self.entity_filter(referenced_entity_ids) + entity_states = [ + _state.state + for entity_id in filtered_entity_ids + if (_state := self._hass.states.get(entity_id)) + and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ] + return matcher(entity_states) + + return test_state + + +def make_entity_state_condition( + domain: str, states: str | set[str] +) -> type[EntityStateConditionBase]: + """Create a condition for entity state changes to specific state(s).""" + + if isinstance(states, str): + states_set = {states} + else: + states_set = states + + class CustomCondition(EntityStateConditionBase): + """Condition for entity state.""" + + _domain = domain + _states = states_set + + return CustomCondition + + class ConditionProtocol(Protocol): """Define the format of condition modules.""" diff --git a/tests/components/fan/test_condition.py b/tests/components/fan/test_condition.py new file mode 100644 index 00000000000..23b520332b7 --- /dev/null +++ b/tests/components/fan/test_condition.py @@ -0,0 +1,261 @@ +"""Test fan conditions.""" + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest + +from homeassistant.components import automation +from homeassistant.const import ( + ATTR_LABEL_ID, + CONF_CONDITION, + CONF_OPTIONS, + CONF_TARGET, + STATE_OFF, + STATE_ON, +) +from homeassistant.core import HomeAssistant +from homeassistant.helpers.condition import ( + ConditionCheckerTypeOptional, + async_from_config, +) +from homeassistant.setup import async_setup_component + +from tests.components import ( + ConditionStateDescription, + parametrize_condition_states, + parametrize_target_entities, + set_or_remove_state, + target_entities, +) + + +@pytest.fixture(autouse=True, name="stub_blueprint_populate") +def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None: + """Stub copying the blueprints to the config folder.""" + + +@pytest.fixture +async def target_fans(hass: HomeAssistant) -> list[str]: + """Create multiple fan entities associated with different targets.""" + return (await target_entities(hass, "fan"))["included"] + + +@pytest.fixture +async def target_switches(hass: HomeAssistant) -> list[str]: + """Create multiple switch entities associated with different targets.""" + return (await target_entities(hass, "switch"))["included"] + + +async def setup_automation_with_fan_condition( + hass: HomeAssistant, + *, + condition: str, + target: dict, + behavior: str, +) -> None: + """Set up automation with fan state condition.""" + await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "trigger": {"platform": "event", "event_type": "test_event"}, + "condition": { + CONF_CONDITION: condition, + CONF_TARGET: target, + CONF_OPTIONS: {"behavior": behavior}, + }, + "action": { + "service": "test.automation", + }, + } + }, + ) + + +async def create_condition( + hass: HomeAssistant, + *, + condition: str, + target: dict, + behavior: str, +) -> ConditionCheckerTypeOptional: + """Create a fan state condition.""" + return await async_from_config( + hass, + { + CONF_CONDITION: condition, + CONF_TARGET: target, + CONF_OPTIONS: {"behavior": behavior}, + }, + ) + + +@pytest.fixture(name="enable_experimental_triggers_conditions") +def enable_experimental_triggers_conditions() -> Generator[None]: + """Enable experimental triggers and conditions.""" + with patch( + "homeassistant.components.labs.async_is_preview_feature_enabled", + return_value=True, + ): + yield + + +@pytest.mark.parametrize( + "condition", + [ + "fan.is_off", + "fan.is_on", + ], +) +async def test_fan_conditions_gated_by_labs_flag( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture, condition: str +) -> None: + """Test the fan conditions are gated by the labs flag.""" + await setup_automation_with_fan_condition( + hass, condition=condition, target={ATTR_LABEL_ID: "test_label"}, behavior="any" + ) + assert ( + "Unnamed automation failed to setup conditions and has been disabled: " + f"Condition '{condition}' requires the experimental 'New triggers and " + "conditions' feature to be enabled in Home Assistant Labs settings " + "(feature flag: 'new_triggers_conditions')" + ) in caplog.text + + +@pytest.mark.usefixtures("enable_experimental_triggers_conditions") +@pytest.mark.parametrize( + ("condition_target_config", "entity_id", "entities_in_target"), + parametrize_target_entities("fan"), +) +@pytest.mark.parametrize( + ("condition", "condition_options", "states"), + [ + *parametrize_condition_states( + condition="fan.is_on", + target_states=[STATE_ON], + other_states=[STATE_OFF], + ), + *parametrize_condition_states( + condition="fan.is_off", + target_states=[STATE_OFF], + other_states=[STATE_ON], + ), + ], +) +async def test_fan_state_condition_behavior_any( + hass: HomeAssistant, + target_fans: list[str], + target_switches: list[str], + condition_target_config: dict, + entity_id: str, + entities_in_target: int, + condition: str, + condition_options: dict[str, Any], + states: list[ConditionStateDescription], +) -> None: + """Test the fan state condition with the 'any' behavior.""" + other_entity_ids = set(target_fans) - {entity_id} + + # Set all fans, including the tested fan, to the initial state + for eid in target_fans: + set_or_remove_state(hass, eid, states[0]["included"]) + await hass.async_block_till_done() + + condition = await create_condition( + hass, + condition=condition, + target=condition_target_config, + behavior="any", + ) + + # Set state for switches to ensure that they don't impact the condition + for state in states: + for eid in target_switches: + set_or_remove_state(hass, eid, state["included"]) + await hass.async_block_till_done() + assert condition(hass) is False + + for state in states: + included_state = state["included"] + set_or_remove_state(hass, entity_id, included_state) + await hass.async_block_till_done() + assert condition(hass) == state["condition_true"] + + # Check if changing other fans also passes the condition + for other_entity_id in other_entity_ids: + set_or_remove_state(hass, other_entity_id, included_state) + await hass.async_block_till_done() + assert condition(hass) == state["condition_true"] + + +@pytest.mark.usefixtures("enable_experimental_triggers_conditions") +@pytest.mark.parametrize( + ("condition_target_config", "entity_id", "entities_in_target"), + parametrize_target_entities("fan"), +) +@pytest.mark.parametrize( + ("condition", "condition_options", "states"), + [ + *parametrize_condition_states( + condition="fan.is_on", + target_states=[STATE_ON], + other_states=[STATE_OFF], + ), + *parametrize_condition_states( + condition="fan.is_off", + target_states=[STATE_OFF], + other_states=[STATE_ON], + ), + ], +) +async def test_fan_state_condition_behavior_all( + hass: HomeAssistant, + target_fans: list[str], + condition_target_config: dict, + entity_id: str, + entities_in_target: int, + condition: str, + condition_options: dict[str, Any], + states: list[ConditionStateDescription], +) -> None: + """Test the fan state condition with the 'all' behavior.""" + # Set state for two switches to ensure that they don't impact the condition + hass.states.async_set("switch.label_switch_1", STATE_OFF) + hass.states.async_set("switch.label_switch_2", STATE_ON) + + other_entity_ids = set(target_fans) - {entity_id} + + # Set all fans, including the tested fan, to the initial state + for eid in target_fans: + set_or_remove_state(hass, eid, states[0]["included"]) + await hass.async_block_till_done() + + condition = await create_condition( + hass, + condition=condition, + target=condition_target_config, + behavior="all", + ) + + for state in states: + included_state = state["included"] + + set_or_remove_state(hass, entity_id, included_state) + await hass.async_block_till_done() + # The condition passes if all entities are either in a target state or invalid + assert condition(hass) == ( + (not state["state_valid"]) + or (state["condition_true"] and entities_in_target == 1) + ) + + for other_entity_id in other_entity_ids: + set_or_remove_state(hass, other_entity_id, included_state) + await hass.async_block_till_done() + + # The condition passes if all entities are either in a target state or invalid + assert condition(hass) == ( + (not state["state_valid"]) or state["condition_true"] + )