diff --git a/homeassistant/components/text/trigger.py b/homeassistant/components/text/trigger.py index 7da70ab00b8..af2480bf888 100644 --- a/homeassistant/components/text/trigger.py +++ b/homeassistant/components/text/trigger.py @@ -1,5 +1,6 @@ -"""Provides triggers for texts.""" +"""Provides triggers for text and input_text entities.""" +from homeassistant.components.input_text import DOMAIN as INPUT_TEXT_DOMAIN from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.core import HomeAssistant, State from homeassistant.helpers.automation import DomainSpec @@ -13,9 +14,9 @@ from .const import DOMAIN class TextChangedTrigger(EntityTriggerBase): - """Trigger for text entity when its content changes.""" + """Trigger for text and input_text entities when their content changes.""" - _domain_specs = {DOMAIN: DomainSpec()} + _domain_specs = {DOMAIN: DomainSpec(), INPUT_TEXT_DOMAIN: DomainSpec()} _schema = ENTITY_STATE_TRIGGER_SCHEMA def is_valid_transition(self, from_state: State, to_state: State) -> bool: @@ -35,5 +36,5 @@ TRIGGERS: dict[str, type[Trigger]] = { async def async_get_triggers(hass: HomeAssistant) -> dict[str, type[Trigger]]: - """Return the triggers for texts.""" + """Return the triggers for text and input_text entities.""" return TRIGGERS diff --git a/homeassistant/components/text/triggers.yaml b/homeassistant/components/text/triggers.yaml index 0cc0cfcf6d2..50f3b4f5987 100644 --- a/homeassistant/components/text/triggers.yaml +++ b/homeassistant/components/text/triggers.yaml @@ -1,4 +1,6 @@ changed: target: entity: - domain: text + domain: + - text + - input_text diff --git a/tests/components/common.py b/tests/components/common.py index 36197e0d117..7f639f029b7 100644 --- a/tests/components/common.py +++ b/tests/components/common.py @@ -184,14 +184,19 @@ class StateDescription(TypedDict): attributes: dict -class TriggerStateDescription(TypedDict): - """Test state and expected service call count.""" +class BasicTriggerStateDescription(TypedDict): + """Test state and expected service call count for targeted entities only.""" included_state: StateDescription # State for entities meant to be targeted - excluded_state: StateDescription # State for entities not meant to be targeted count: int # Expected service call count +class TriggerStateDescription(BasicTriggerStateDescription): + """Test state and expected service call count for both included and excluded entities.""" + + excluded_state: StateDescription # State for entities not meant to be targeted + + class ConditionStateDescription(TypedDict): """Test state and expected condition evaluation.""" diff --git a/tests/components/text/test_trigger.py b/tests/components/text/test_trigger.py index 5092d9ae0b9..1a76eaed44f 100644 --- a/tests/components/text/test_trigger.py +++ b/tests/components/text/test_trigger.py @@ -2,11 +2,13 @@ import pytest +from homeassistant.components.input_text import DOMAIN as INPUT_TEXT_DOMAIN +from homeassistant.components.text.const import DOMAIN from homeassistant.const import CONF_ENTITY_ID, STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.core import HomeAssistant, ServiceCall from tests.components.common import ( - TriggerStateDescription, + BasicTriggerStateDescription, arm_trigger, assert_trigger_gated_by_labs_flag, parametrize_target_entities, @@ -14,11 +16,114 @@ from tests.components.common import ( target_entities, ) +TEST_TRIGGER_STATES = [ + ( + "text.changed", + [ + { + "included_state": {"state": None, "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "bar", "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "baz", "attributes": {}}, + "count": 1, + }, + ], + ), + ( + "text.changed", + [ + { + "included_state": {"state": "foo", "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "bar", "attributes": {}}, + "count": 1, + }, + { + "included_state": {"state": "baz", "attributes": {}}, + "count": 1, + }, + ], + ), + ( + "text.changed", + [ + { + "included_state": {"state": "foo", "attributes": {}}, + "count": 0, + }, + # empty string + { + "included_state": {"state": "", "attributes": {}}, + "count": 1, + }, + { + "included_state": {"state": "baz", "attributes": {}}, + "count": 1, + }, + ], + ), + ( + "text.changed", + [ + { + "included_state": {"state": STATE_UNAVAILABLE, "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "bar", "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "baz", "attributes": {}}, + "count": 1, + }, + { + "included_state": {"state": STATE_UNAVAILABLE, "attributes": {}}, + "count": 0, + }, + ], + ), + ( + "text.changed", + [ + { + "included_state": {"state": STATE_UNKNOWN, "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "bar", "attributes": {}}, + "count": 0, + }, + { + "included_state": {"state": "baz", "attributes": {}}, + "count": 1, + }, + { + "included_state": {"state": STATE_UNKNOWN, "attributes": {}}, + "count": 0, + }, + ], + ), +] + @pytest.fixture async def target_texts(hass: HomeAssistant) -> dict[str, list[str]]: """Create multiple text entities associated with different targets.""" - return await target_entities(hass, "text") + return await target_entities(hass, DOMAIN) + + +@pytest.fixture +async def target_input_texts(hass: HomeAssistant) -> dict[str, list[str]]: + """Create multiple input_text entities associated with different targets.""" + return await target_entities(hass, INPUT_TEXT_DOMAIN) @pytest.mark.parametrize("trigger_key", ["text.changed"]) @@ -32,110 +137,9 @@ async def test_text_triggers_gated_by_labs_flag( @pytest.mark.usefixtures("enable_labs_preview_features") @pytest.mark.parametrize( ("trigger_target_config", "entity_id", "entities_in_target"), - parametrize_target_entities("text"), -) -@pytest.mark.parametrize( - ("trigger", "states"), - [ - ( - "text.changed", - [ - { - "included_state": {"state": None, "attributes": {}}, - "count": 0, - }, - { - "included_state": {"state": "bar", "attributes": {}}, - "count": 0, - }, - { - "included_state": {"state": "baz", "attributes": {}}, - "count": 1, - }, - ], - ), - ( - "text.changed", - [ - { - "included_state": {"state": "foo", "attributes": {}}, - "count": 0, - }, - { - "included_state": {"state": "bar", "attributes": {}}, - "count": 1, - }, - { - "included_state": {"state": "baz", "attributes": {}}, - "count": 1, - }, - ], - ), - ( - "text.changed", - [ - { - "included_state": {"state": "foo", "attributes": {}}, - "count": 0, - }, - # empty string - {"included_state": {"state": "", "attributes": {}}, "count": 1}, - { - "included_state": {"state": "baz", "attributes": {}}, - "count": 1, - }, - ], - ), - ( - "text.changed", - [ - { - "included_state": { - "state": STATE_UNAVAILABLE, - "attributes": {}, - }, - "count": 0, - }, - { - "included_state": {"state": "bar", "attributes": {}}, - "count": 0, - }, - { - "included_state": {"state": "baz", "attributes": {}}, - "count": 1, - }, - { - "included_state": { - "state": STATE_UNAVAILABLE, - "attributes": {}, - }, - "count": 0, - }, - ], - ), - ( - "text.changed", - [ - { - "included_state": {"state": STATE_UNKNOWN, "attributes": {}}, - "count": 0, - }, - { - "included_state": {"state": "bar", "attributes": {}}, - "count": 0, - }, - { - "included_state": {"state": "baz", "attributes": {}}, - "count": 1, - }, - { - "included_state": {"state": STATE_UNKNOWN, "attributes": {}}, - "count": 0, - }, - ], - ), - ], + parametrize_target_entities(DOMAIN), ) +@pytest.mark.parametrize(("trigger", "states"), TEST_TRIGGER_STATES) async def test_text_state_trigger( hass: HomeAssistant, service_calls: list[ServiceCall], @@ -144,7 +148,7 @@ async def test_text_state_trigger( entity_id: str, entities_in_target: int, trigger: str, - states: list[TriggerStateDescription], + states: list[BasicTriggerStateDescription], ) -> None: """Test that the text state trigger fires when targeted text state changes.""" other_entity_ids = set(target_texts["included_entities"]) - {entity_id} @@ -152,7 +156,7 @@ async def test_text_state_trigger( # Set all texts, including the tested text, to the initial state for eid in target_texts["included_entities"]: set_or_remove_state(hass, eid, states[0]["included_state"]) - await hass.async_block_till_done() + await hass.async_block_till_done() await arm_trigger(hass, trigger, None, trigger_target_config) @@ -168,6 +172,49 @@ async def test_text_state_trigger( # Check if changing other texts also triggers for other_entity_id in other_entity_ids: set_or_remove_state(hass, other_entity_id, included_state) - await hass.async_block_till_done() + await hass.async_block_till_done() + assert len(service_calls) == (entities_in_target - 1) * state["count"] + service_calls.clear() + + +@pytest.mark.usefixtures("enable_labs_preview_features") +@pytest.mark.parametrize( + ("trigger_target_config", "entity_id", "entities_in_target"), + parametrize_target_entities(INPUT_TEXT_DOMAIN), +) +@pytest.mark.parametrize(("trigger", "states"), TEST_TRIGGER_STATES) +async def test_input_text_state_trigger( + hass: HomeAssistant, + service_calls: list[ServiceCall], + target_input_texts: dict[str, list[str]], + trigger_target_config: dict, + entity_id: str, + entities_in_target: int, + trigger: str, + states: list[BasicTriggerStateDescription], +) -> None: + """Test that the `text.changed` trigger fires when any input_text entity's state changes.""" + other_entity_ids = set(target_input_texts["included_entities"]) - {entity_id} + + # Set all input_texts, including the tested input_text, to the initial state + for eid in target_input_texts["included_entities"]: + set_or_remove_state(hass, eid, states[0]["included_state"]) + await hass.async_block_till_done() + + await arm_trigger(hass, trigger, None, trigger_target_config) + + for state in states[1:]: + included_state = state["included_state"] + set_or_remove_state(hass, entity_id, included_state) + await hass.async_block_till_done() + assert len(service_calls) == state["count"] + for service_call in service_calls: + assert service_call.data[CONF_ENTITY_ID] == entity_id + service_calls.clear() + + # Check if changing other input_texts also triggers + for other_entity_id in other_entity_ids: + set_or_remove_state(hass, other_entity_id, included_state) + await hass.async_block_till_done() assert len(service_calls) == (entities_in_target - 1) * state["count"] service_calls.clear()