diff --git a/tests/components/__init__.py b/tests/components/__init__.py index 8c251b7d27b..d103b233971 100644 --- a/tests/components/__init__.py +++ b/tests/components/__init__.py @@ -172,6 +172,90 @@ class StateDescription(TypedDict): count: int +class ConditionStateDescription(TypedDict): + """Test state and expected service call count.""" + + included: _StateDescription + excluded: _StateDescription + condition_true: bool + state_valid: bool + + +def parametrize_condition_states( + *, + condition: str, + condition_options: dict[str, Any] | None = None, + target_states: list[str | None | tuple[str | None, dict]], + other_states: list[str | None | tuple[str | None, dict]], + additional_attributes: dict | None = None, +) -> list[tuple[str, dict[str, Any], list[ConditionStateDescription]]]: + """Parametrize states and expected service call counts. + + The target_states and other_states iterables are either iterables of + states or iterables of (state, attributes) tuples. + + Returns a list of tuples with (condition, condition options, list of states), + where states is a list of ConditionStateDescription dicts. + """ + + additional_attributes = additional_attributes or {} + condition_options = condition_options or {} + + def state_with_attributes( + state: str | None | tuple[str | None, dict], + condition_true: bool, + state_valid: bool, + ) -> ConditionStateDescription: + """Return (state, attributes) dict.""" + if isinstance(state, str) or state is None: + return { + "included": { + "state": state, + "attributes": additional_attributes, + }, + "excluded": { + "state": state, + "attributes": {}, + }, + "condition_true": condition_true, + "state_valid": state_valid, + } + return { + "included": { + "state": state[0], + "attributes": state[1] | additional_attributes, + }, + "excluded": { + "state": state[0], + "attributes": state[1], + }, + "condition_true": condition_true, + "state_valid": state_valid, + } + + return [ + ( + condition, + condition_options, + list( + itertools.chain( + (state_with_attributes(None, False, False),), + (state_with_attributes(STATE_UNAVAILABLE, False, False),), + (state_with_attributes(STATE_UNKNOWN, False, False),), + ( + state_with_attributes(other_state, False, True) + for other_state in other_states + ), + ( + state_with_attributes(target_state, True, True) + for target_state in target_states + ), + ) + ), + ), + ] + + def parametrize_trigger_states( *, trigger: str, @@ -202,7 +286,7 @@ def parametrize_trigger_states( def state_with_attributes( state: str | None | tuple[str | None, dict], count: int - ) -> dict: + ) -> StateDescription: """Return (state, attributes) dict.""" if isinstance(state, str) or state is None: return { diff --git a/tests/components/light/test_condition.py b/tests/components/light/test_condition.py index 8eb9fb4c50c..d2f3c100c98 100644 --- a/tests/components/light/test_condition.py +++ b/tests/components/light/test_condition.py @@ -1,6 +1,7 @@ """Test light conditions.""" from collections.abc import Generator +from typing import Any from unittest.mock import patch import pytest @@ -13,24 +14,18 @@ from homeassistant.const import ( CONF_TARGET, STATE_OFF, STATE_ON, - STATE_UNAVAILABLE, - STATE_UNKNOWN, ) from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.setup import async_setup_component from tests.components import ( + ConditionStateDescription, + parametrize_condition_states, parametrize_target_entities, set_or_remove_state, target_entities, ) -INVALID_STATES = [ - {"state": STATE_UNAVAILABLE, "attributes": {}}, - {"state": STATE_UNKNOWN, "attributes": {}}, - {"state": None, "attributes": {}}, -] - @pytest.fixture(autouse=True, name="stub_blueprint_populate") def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None: @@ -76,15 +71,15 @@ async def setup_automation_with_light_condition( ) -async def has_call_after_trigger( +async def has_single_call_after_trigger( hass: HomeAssistant, service_calls: list[ServiceCall] ) -> bool: - """Check if there are service calls after the trigger event.""" + """Check if there is a single service call after the trigger event.""" hass.bus.async_fire("test_event") await hass.async_block_till_done() - has_calls = len(service_calls) == 1 + num_calls = len(service_calls) service_calls.clear() - return has_calls + return num_calls == 1 @pytest.fixture(name="enable_experimental_triggers_conditions") @@ -125,17 +120,17 @@ async def test_light_conditions_gated_by_labs_flag( parametrize_target_entities("light"), ) @pytest.mark.parametrize( - ("condition", "target_state", "other_state"), + ("condition", "condition_options", "states"), [ - ( - "light.is_on", - {"state": STATE_ON, "attributes": {}}, - {"state": STATE_OFF, "attributes": {}}, + *parametrize_condition_states( + condition="light.is_on", + target_states=[STATE_ON], + other_states=[STATE_OFF], ), - ( - "light.is_off", - {"state": STATE_OFF, "attributes": {}}, - {"state": STATE_ON, "attributes": {}}, + *parametrize_condition_states( + condition="light.is_off", + target_states=[STATE_OFF], + other_states=[STATE_ON], ), ], ) @@ -148,15 +143,15 @@ async def test_light_state_condition_behavior_any( entity_id: str, entities_in_target: int, condition: str, - target_state: str, - other_state: str, + condition_options: dict[str, Any], + states: list[ConditionStateDescription], ) -> None: """Test the light state condition with the 'any' behavior.""" other_entity_ids = set(target_lights) - {entity_id} # Set all lights, including the tested light, to the initial state for eid in target_lights: - set_or_remove_state(hass, eid, other_state) + set_or_remove_state(hass, eid, states[0]["included"]) await hass.async_block_till_done() await setup_automation_with_light_condition( @@ -167,38 +162,29 @@ async def test_light_state_condition_behavior_any( ) # Set state for switches to ensure that they don't impact the condition - for eid in target_switches: - set_or_remove_state(hass, eid, other_state) - await hass.async_block_till_done() - assert not await has_call_after_trigger(hass, service_calls) + for state in states: + for eid in target_switches: + set_or_remove_state(hass, eid, state["included"]) + await hass.async_block_till_done() + assert not await has_single_call_after_trigger(hass, service_calls) - for eid in target_switches: - set_or_remove_state(hass, eid, target_state) - await hass.async_block_till_done() - assert not await has_call_after_trigger(hass, service_calls) - - # Set one light to the condition state -> condition pass - set_or_remove_state(hass, entity_id, target_state) - assert await has_call_after_trigger(hass, service_calls) - - # Set all remaining lights to the condition state -> condition pass - for eid in other_entity_ids: - set_or_remove_state(hass, eid, target_state) - assert await has_call_after_trigger(hass, service_calls) - - for invalid_state in INVALID_STATES: - # Set one light to the invalid state -> condition pass if there are - # other lights in the condition state - set_or_remove_state(hass, entity_id, invalid_state) - assert await has_call_after_trigger(hass, service_calls) == bool( - entities_in_target - 1 + for state in states: + included_state = state["included"] + set_or_remove_state(hass, entity_id, included_state) + await hass.async_block_till_done() + assert ( + await has_single_call_after_trigger(hass, service_calls) + == state["condition_true"] ) - for invalid_state in INVALID_STATES: - # Set all lights to invalid state -> condition fail - for eid in other_entity_ids: - set_or_remove_state(hass, eid, invalid_state) - assert not await has_call_after_trigger(hass, service_calls) + # Check if changing other lights also passes the condition + for other_entity_id in other_entity_ids: + set_or_remove_state(hass, other_entity_id, included_state) + await hass.async_block_till_done() + assert ( + await has_single_call_after_trigger(hass, service_calls) + == state["condition_true"] + ) @pytest.mark.usefixtures("enable_experimental_triggers_conditions") @@ -207,17 +193,17 @@ async def test_light_state_condition_behavior_any( parametrize_target_entities("light"), ) @pytest.mark.parametrize( - ("condition", "target_state", "other_state"), + ("condition", "condition_options", "states"), [ - ( - "light.is_on", - {"state": STATE_ON, "attributes": {}}, - {"state": STATE_OFF, "attributes": {}}, + *parametrize_condition_states( + condition="light.is_on", + target_states=[STATE_ON], + other_states=[STATE_OFF], ), - ( - "light.is_off", - {"state": STATE_OFF, "attributes": {}}, - {"state": STATE_ON, "attributes": {}}, + *parametrize_condition_states( + condition="light.is_off", + target_states=[STATE_OFF], + other_states=[STATE_ON], ), ], ) @@ -229,8 +215,8 @@ async def test_light_state_condition_behavior_all( entity_id: str, entities_in_target: int, condition: str, - target_state: str, - other_state: str, + condition_options: dict[str, Any], + states: list[ConditionStateDescription], ) -> None: """Test the light state condition with the 'all' behavior.""" # Set state for two switches to ensure that they don't impact the condition @@ -241,7 +227,7 @@ async def test_light_state_condition_behavior_all( # Set all lights, including the tested light, to the initial state for eid in target_lights: - set_or_remove_state(hass, eid, other_state) + set_or_remove_state(hass, eid, states[0]["included"]) await hass.async_block_till_done() await setup_automation_with_light_condition( @@ -251,27 +237,22 @@ async def test_light_state_condition_behavior_all( behavior="all", ) - # No lights on the condition state - assert not await has_call_after_trigger(hass, service_calls) + for state in states: + included_state = state["included"] - # Set one light to the condition state -> condition fail - set_or_remove_state(hass, entity_id, target_state) - assert await has_call_after_trigger(hass, service_calls) == ( - entities_in_target == 1 - ) + set_or_remove_state(hass, entity_id, included_state) + await hass.async_block_till_done() + # The condition passes if all entities are either in a target state or invalid + assert await has_single_call_after_trigger(hass, service_calls) == ( + (not state["state_valid"]) + or (state["condition_true"] and entities_in_target == 1) + ) - # Set all remaining lights to the condition state -> condition pass - for eid in other_entity_ids: - set_or_remove_state(hass, eid, target_state) - assert await has_call_after_trigger(hass, service_calls) + for other_entity_id in other_entity_ids: + set_or_remove_state(hass, other_entity_id, included_state) + await hass.async_block_till_done() - for invalid_state in INVALID_STATES: - # Set one light to the invalid state -> condition still pass - set_or_remove_state(hass, entity_id, invalid_state) - assert await has_call_after_trigger(hass, service_calls) - - for invalid_state in INVALID_STATES: - # Set all lights to unavailable -> condition passes - for eid in other_entity_ids: - set_or_remove_state(hass, eid, invalid_state) - assert await has_call_after_trigger(hass, service_calls) + # The condition passes if all entities are either in a target state or invalid + assert await has_single_call_after_trigger(hass, service_calls) == ( + (not state["state_valid"]) or state["condition_true"] + )