diff --git a/homeassistant/components/websocket_api/automation.py b/homeassistant/components/websocket_api/automation.py index 1cc9019eb4a..7794012be6f 100644 --- a/homeassistant/components/websocket_api/automation.py +++ b/homeassistant/components/websocket_api/automation.py @@ -33,6 +33,10 @@ FLATTENED_SERVICE_DESCRIPTIONS_CACHE: HassKey[ tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]] ] = HassKey("websocket_automation_flat_service_description_cache") +AUTOMATION_COMPONENT_LOOKUP_CACHE: HassKey[ + dict[str, tuple[Mapping[str, Any], _AutomationComponentLookupTable]] +] = HassKey("websocket_automation_component_lookup_cache") + @dataclass(slots=True, kw_only=True) class _EntityFilter: @@ -107,6 +111,14 @@ class _AutomationComponentLookupData: ) +@dataclass(slots=True, kw_only=True) +class _AutomationComponentLookupTable: + """Helper class for looking up automation components.""" + + domain_components: dict[str | None, list[_AutomationComponentLookupData]] + component_count: int + + def _get_automation_component_domains( target_description: dict[str, Any], ) -> set[str | None]: @@ -138,8 +150,51 @@ def _get_automation_component_domains( return domains +def _get_automation_component_lookup_table( + hass: HomeAssistant, + component_type: str, + component_descriptions: Mapping[str, Mapping[str, Any] | None], +) -> _AutomationComponentLookupTable: + """Get a dict of automation components keyed by domain, along with the total number of components. + + Returns a cached object if available. + """ + + try: + cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE] + except KeyError: + cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE] = {} + + if (cached := cache.get(component_type)) is not None: + cached_descriptions, cached_lookup = cached + if cached_descriptions is component_descriptions: + _LOGGER.debug( + "Using cached automation component lookup data for %s", component_type + ) + return cached_lookup + + lookup_table = _AutomationComponentLookupTable( + domain_components={}, component_count=0 + ) + for component, description in component_descriptions.items(): + if description is None or CONF_TARGET not in description: + _LOGGER.debug("Skipping component %s without target description", component) + continue + domains = _get_automation_component_domains(description[CONF_TARGET]) + lookup_data = _AutomationComponentLookupData.create( + component, description[CONF_TARGET] + ) + for domain in domains: + lookup_table.domain_components.setdefault(domain, []).append(lookup_data) + lookup_table.component_count += 1 + + cache[component_type] = (component_descriptions, lookup_table) + return lookup_table + + def _async_get_automation_components_for_target( hass: HomeAssistant, + component_type: str, target_selection: ConfigType, expand_group: bool, component_descriptions: Mapping[str, Mapping[str, Any] | None], @@ -155,27 +210,17 @@ def _async_get_automation_components_for_target( ) _LOGGER.debug("Extracted entities for lookup: %s", extracted) - # Build lookup structure: domain -> list of trigger/condition/service lookup data - domain_components: dict[str | None, list[_AutomationComponentLookupData]] = {} - component_count = 0 - for component, description in component_descriptions.items(): - if description is None or CONF_TARGET not in description: - _LOGGER.debug("Skipping component %s without target description", component) - continue - domains = _get_automation_component_domains(description[CONF_TARGET]) - lookup_data = _AutomationComponentLookupData.create( - component, description[CONF_TARGET] - ) - for domain in domains: - domain_components.setdefault(domain, []).append(lookup_data) - component_count += 1 - - _LOGGER.debug("Automation components per domain: %s", domain_components) + lookup_table = _get_automation_component_lookup_table( + hass, component_type, component_descriptions + ) + _LOGGER.debug( + "Automation components per domain: %s", lookup_table.domain_components + ) entity_infos = entity_sources(hass) matched_components: set[str] = set() for entity_id in extracted.referenced | extracted.indirectly_referenced: - if component_count == len(matched_components): + if lookup_table.component_count == len(matched_components): # All automation components matched already, so we don't need to iterate further break @@ -187,7 +232,11 @@ def _async_get_automation_components_for_target( entity_domain = entity_id.split(".")[0] entity_integration = entity_info["domain"] for domain in (entity_domain, entity_integration, None): - for component_data in domain_components.get(domain, []): + if not ( + domain_component_data := lookup_table.domain_components.get(domain) + ): + continue + for component_data in domain_component_data: if component_data.component in matched_components: continue if component_data.matches( @@ -204,7 +253,7 @@ async def async_get_triggers_for_target( """Get triggers for a target.""" descriptions = await async_get_all_trigger_descriptions(hass) return _async_get_automation_components_for_target( - hass, target_selector, expand_group, descriptions + hass, "triggers", target_selector, expand_group, descriptions ) @@ -214,7 +263,7 @@ async def async_get_conditions_for_target( """Get conditions for a target.""" descriptions = await async_get_all_condition_descriptions(hass) return _async_get_automation_components_for_target( - hass, target_selector, expand_group, descriptions + hass, "conditions", target_selector, expand_group, descriptions ) @@ -247,5 +296,9 @@ async def async_get_services_for_target( return flattened_descriptions return _async_get_automation_components_for_target( - hass, target_selector, expand_group, get_flattened_service_descriptions() + hass, + "services", + target_selector, + expand_group, + get_flattened_service_descriptions(), ) diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 48ca34aa8fd..5d7bfabdb80 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -24,6 +24,10 @@ from homeassistant.components.websocket_api.auth import ( TYPE_AUTH_OK, TYPE_AUTH_REQUIRED, ) +from homeassistant.components.websocket_api.automation import ( + AUTOMATION_COMPONENT_LOOKUP_CACHE, + _get_automation_component_lookup_table, +) from homeassistant.components.websocket_api.commands import ( ALL_CONDITION_DESCRIPTIONS_JSON_CACHE, ALL_SERVICE_DESCRIPTIONS_JSON_CACHE, @@ -3665,6 +3669,7 @@ async def test_get_triggers_conditions_for_target( hass: HomeAssistant, websocket_client: MockHAClientWebSocket, automation_component: str, + caplog: pytest.LogCaptureFixture, ) -> None: """Test get_triggers_for_target/get_conditions_for_target command with mixed target types.""" @@ -3803,7 +3808,9 @@ async def test_get_triggers_conditions_for_target( await hass.async_block_till_done() async def assert_command( - target: dict[str, list[str]], expected: list[str] + target: dict[str, list[str]], + expected: list[str], + expect_lookup_cache: bool = True, ) -> Any: """Call the command and assert expected triggers/conditions.""" await websocket_client.send_json_auto_id( @@ -3815,8 +3822,15 @@ async def test_get_triggers_conditions_for_target( assert msg["success"] assert sorted(msg["result"]) == sorted(expected) + assert ( + "Using cached automation component lookup data" in caplog.text + ) == expect_lookup_cache + caplog.clear() + # Test entity target - unknown entity - await assert_command({"entity_id": ["light.unknown_entity"]}, []) + await assert_command( + {"entity_id": ["light.unknown_entity"]}, [], expect_lookup_cache=False + ) # Test entity target - entity not in registry await assert_command( @@ -3936,6 +3950,7 @@ async def test_get_services_for_target( mock_load_yaml: Mock, hass: HomeAssistant, websocket_client: MockHAClientWebSocket, + caplog: pytest.LogCaptureFixture, ) -> None: """Test get_services_for_target command with mixed target types.""" @@ -4047,7 +4062,11 @@ async def test_get_services_for_target( ) await hass.async_block_till_done() - async def assert_services(target: dict[str, list[str]], expected: list[str]) -> Any: + async def assert_services( + target: dict[str, list[str]], + expected: list[str], + expect_lookup_cache: bool = True, + ) -> Any: """Call the command and assert expected services.""" await websocket_client.send_json_auto_id( {"type": "get_services_for_target", "target": target} @@ -4058,8 +4077,15 @@ async def test_get_services_for_target( assert msg["success"] assert sorted(msg["result"]) == sorted(expected) + assert ( + "Using cached automation component lookup data" in caplog.text + ) == expect_lookup_cache + caplog.clear() + # Test entity target - unknown entity - await assert_services({"entity_id": ["light.unknown_entity"]}, []) + await assert_services( + {"entity_id": ["light.unknown_entity"]}, [], expect_lookup_cache=False + ) # Test entity target - entity not in registry await assert_services( @@ -4212,7 +4238,7 @@ async def test_get_services_for_target_caching( await call_command() assert mock_get_components.call_count == 1 - first_flat_descriptions = mock_get_components.call_args_list[0][0][3] + first_flat_descriptions = mock_get_components.call_args_list[0][0][4] assert first_flat_descriptions == { "light.turn_on": { "fields": {}, @@ -4227,7 +4253,7 @@ async def test_get_services_for_target_caching( # Second call: should reuse cached flat descriptions await call_command() assert mock_get_components.call_count == 2 - second_flat_descriptions = mock_get_components.call_args_list[1][0][3] + second_flat_descriptions = mock_get_components.call_args_list[1][0][4] assert first_flat_descriptions is second_flat_descriptions # Register a new service to invalidate cache @@ -4237,6 +4263,89 @@ async def test_get_services_for_target_caching( # Third call: cache should be rebuilt await call_command() assert mock_get_components.call_count == 3 - third_flat_descriptions = mock_get_components.call_args_list[2][0][3] + third_flat_descriptions = mock_get_components.call_args_list[2][0][4] assert "new_domain.new_service" in third_flat_descriptions assert third_flat_descriptions is not first_flat_descriptions + + +async def test_get_automation_component_lookup_table_cache( + hass: HomeAssistant, +) -> None: + """Test that _get_automation_component_lookup_table caches and rotates properly.""" + triggers: dict[str, dict[str, Any] | None] = { + "light.turned_on": {"target": {"entity": [{"domain": ["light"]}]}}, + "switch.turned_on": {"target": {"entity": [{"domain": ["switch"]}]}}, + } + conditions: dict[str, dict[str, Any] | None] = { + "light.is_on": {"target": {"entity": [{"domain": ["light"]}]}}, + "sensor.is_above": {"target": {"entity": [{"domain": ["sensor"]}]}}, + } + services: dict[str, dict[str, Any] | None] = { + "light.turn_on": {"target": {"entity": [{"domain": ["light"]}]}}, + "climate.set_temperature": {"target": {"entity": [{"domain": ["climate"]}]}}, + } + + # First call with triggers - cache should be created with 1 entry + trigger_result1 = _get_automation_component_lookup_table(hass, "triggers", triggers) + assert AUTOMATION_COMPONENT_LOOKUP_CACHE in hass.data + cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE] + assert len(cache) == 1 + + # Second call with same triggers - should return cached result + trigger_result2 = _get_automation_component_lookup_table(hass, "triggers", triggers) + assert trigger_result1 is trigger_result2 + assert len(cache) == 1 + + # Call with conditions + condition_result1 = _get_automation_component_lookup_table( + hass, "conditions", conditions + ) + assert condition_result1 is not trigger_result1 + assert len(cache) == 2 + + # Call with services + service_result1 = _get_automation_component_lookup_table(hass, "services", services) + assert service_result1 is not trigger_result1 + assert service_result1 is not condition_result1 + assert len(cache) == 3 + + # Verify all 3 return cached results + assert ( + _get_automation_component_lookup_table(hass, "triggers", triggers) + is trigger_result1 + ) + assert ( + _get_automation_component_lookup_table(hass, "conditions", conditions) + is condition_result1 + ) + assert ( + _get_automation_component_lookup_table(hass, "services", services) + is service_result1 + ) + assert len(cache) == 3 + + # Add a new triggers description dict - replaces previous triggers cache + new_triggers: dict[str, dict[str, Any] | None] = { + "fan.turned_on": {"target": {"entity": [{"domain": ["fan"]}]}}, + } + _get_automation_component_lookup_table(hass, "triggers", new_triggers) + assert len(cache) == 3 + + # Initial trigger cache entry should have been replaced + trigger_result3 = _get_automation_component_lookup_table(hass, "triggers", triggers) + assert trigger_result3 is not trigger_result1 + assert len(cache) == 3 + + # Verify all 3 return cached results again + assert ( + _get_automation_component_lookup_table(hass, "triggers", triggers) + is trigger_result3 + ) + assert ( + _get_automation_component_lookup_table(hass, "conditions", conditions) + is condition_result1 + ) + assert ( + _get_automation_component_lookup_table(hass, "services", services) + is service_result1 + )