diff --git a/homeassistant/helpers/target.py b/homeassistant/helpers/target.py index 0b902ea4d23..5286daaeef0 100644 --- a/homeassistant/helpers/target.py +++ b/homeassistant/helpers/target.py @@ -268,11 +268,13 @@ class TargetStateChangeTracker: hass: HomeAssistant, selector_data: TargetSelectorData, action: Callable[[TargetStateChangedData], Any], + entity_filter: Callable[[set[str]], set[str]], ) -> None: """Initialize the state change tracker.""" self._hass = hass self._selector_data = selector_data self._action = action + self._entity_filter = entity_filter self._state_change_unsub: CALLBACK_TYPE | None = None self._registry_unsubs: list[CALLBACK_TYPE] = [] @@ -289,7 +291,9 @@ class TargetStateChangeTracker: self._hass, self._selector_data, expand_group=False ) - tracked_entities = selected.referenced.union(selected.indirectly_referenced) + tracked_entities = self._entity_filter( + selected.referenced.union(selected.indirectly_referenced) + ) @callback def state_change_listener(event: Event[EventStateChangedData]) -> None: @@ -348,6 +352,7 @@ def async_track_target_selector_state_change_event( hass: HomeAssistant, target_selector_config: ConfigType, action: Callable[[TargetStateChangedData], Any], + entity_filter: Callable[[set[str]], set[str]] = lambda x: x, ) -> CALLBACK_TYPE: """Track state changes for entities referenced directly or indirectly in a target selector.""" selector_data = TargetSelectorData(target_selector_config) @@ -355,5 +360,5 @@ def async_track_target_selector_state_change_event( raise HomeAssistantError( f"Target selector {target_selector_config} does not have any selectors defined" ) - tracker = TargetStateChangeTracker(hass, selector_data, action) + tracker = TargetStateChangeTracker(hass, selector_data, action, entity_filter) return tracker.async_setup() diff --git a/tests/helpers/test_target.py b/tests/helpers/test_target.py index fa31ef375fd..09fb16cbe9a 100644 --- a/tests/helpers/test_target.py +++ b/tests/helpers/test_target.py @@ -36,6 +36,29 @@ from tests.common import ( ) +async def set_states_and_check_target_events( + hass: HomeAssistant, + events: list[target.TargetStateChangedData], + state: str, + entities_to_set_state: list[str], + entities_to_assert_change: list[str], +) -> None: + """Toggle the state entities and check for events.""" + for entity_id in entities_to_set_state: + hass.states.async_set(entity_id, state) + await hass.async_block_till_done() + + assert len(events) == len(entities_to_assert_change) + entities_seen = set() + for event in events: + state_change_event = event.state_change_event + entities_seen.add(state_change_event.data["entity_id"]) + assert state_change_event.data["new_state"].state == state + assert event.targeted_entity_ids == set(entities_to_assert_change) + assert entities_seen == set(entities_to_assert_change) + events.clear() + + @pytest.fixture def registries_mock(hass: HomeAssistant) -> None: """Mock including floor and area info.""" @@ -497,19 +520,9 @@ async def test_async_track_target_selector_state_change_event( """Toggle the state entities and check for events.""" nonlocal last_state last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF - for entity_id in entities_to_set_state: - hass.states.async_set(entity_id, last_state) - await hass.async_block_till_done() - - assert len(events) == len(entities_to_assert_change) - entities_seen = set() - for event in events: - state_change_event = event.state_change_event - entities_seen.add(state_change_event.data["entity_id"]) - assert state_change_event.data["new_state"].state == last_state - assert event.targeted_entity_ids == set(entities_to_assert_change) - assert entities_seen == set(entities_to_assert_change) - events.clear() + await set_states_and_check_target_events( + hass, events, last_state, entities_to_set_state, entities_to_assert_change + ) config_entry = MockConfigEntry(domain="test") config_entry.add_to_hass(hass) @@ -645,3 +658,91 @@ async def test_async_track_target_selector_state_change_event( # After unsubscribing, changes should not trigger unsub() await set_states_and_check_events(targeted_entities, []) + + +async def test_async_track_target_selector_state_change_event_filter( + hass: HomeAssistant, +) -> None: + """Test async_track_target_selector_state_change_event with entity filter.""" + events: list[target.TargetStateChangedData] = [] + + filtered_entity = "" + + @callback + def entity_filter(entity_ids: set[str]) -> set[str]: + return {entity_id for entity_id in entity_ids if entity_id != filtered_entity} + + @callback + def state_change_callback(event: target.TargetStateChangedData): + """Handle state change events.""" + events.append(event) + + last_state = STATE_OFF + + async def set_states_and_check_events( + entities_to_set_state: list[str], entities_to_assert_change: list[str] + ) -> None: + """Toggle the state entities and check for events.""" + nonlocal last_state + last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF + await set_states_and_check_target_events( + hass, events, last_state, entities_to_set_state, entities_to_assert_change + ) + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + entity_reg = er.async_get(hass) + + label = lr.async_get(hass).async_create("Test Label").name + label_entity = entity_reg.async_get_or_create( + domain="light", + platform="test", + unique_id="label_light", + ).entity_id + entity_reg.async_update_entity(label_entity, labels={label}) + + targeted_entity = "light.test_light" + + targeted_entities = [targeted_entity, label_entity] + await set_states_and_check_events(targeted_entities, []) + + selector_config = { + ATTR_ENTITY_ID: targeted_entity, + ATTR_LABEL_ID: label, + } + unsub = target.async_track_target_selector_state_change_event( + hass, selector_config, state_change_callback, entity_filter + ) + + await set_states_and_check_events( + targeted_entities, [targeted_entity, label_entity] + ) + + filtered_entity = targeted_entity + # Fire an event so that the targeted entities are re-evaluated + hass.bus.async_fire( + er.EVENT_ENTITY_REGISTRY_UPDATED, + { + "action": "update", + "entity_id": "light.other", + "changes": {}, + }, + ) + await set_states_and_check_events([targeted_entity, label_entity], [label_entity]) + + filtered_entity = label_entity + # Fire an event so that the targeted entities are re-evaluated + hass.bus.async_fire( + er.EVENT_ENTITY_REGISTRY_UPDATED, + { + "action": "update", + "entity_id": "light.other", + "changes": {}, + }, + ) + await set_states_and_check_events( + [targeted_entity, label_entity], [targeted_entity] + ) + + unsub()