1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-20 02:48:57 +00:00

Add lookup caching to get_x_for_target (#157888)

This commit is contained in:
Abílio Costa
2025-12-16 12:17:58 +00:00
committed by GitHub
parent 9ba252d8e3
commit 7eecdc87fd
2 changed files with 190 additions and 28 deletions

View File

@@ -33,6 +33,10 @@ FLATTENED_SERVICE_DESCRIPTIONS_CACHE: HassKey[
tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]
] = HassKey("websocket_automation_flat_service_description_cache")
AUTOMATION_COMPONENT_LOOKUP_CACHE: HassKey[
dict[str, tuple[Mapping[str, Any], _AutomationComponentLookupTable]]
] = HassKey("websocket_automation_component_lookup_cache")
@dataclass(slots=True, kw_only=True)
class _EntityFilter:
@@ -107,6 +111,14 @@ class _AutomationComponentLookupData:
)
@dataclass(slots=True, kw_only=True)
class _AutomationComponentLookupTable:
"""Helper class for looking up automation components."""
domain_components: dict[str | None, list[_AutomationComponentLookupData]]
component_count: int
def _get_automation_component_domains(
target_description: dict[str, Any],
) -> set[str | None]:
@@ -138,8 +150,51 @@ def _get_automation_component_domains(
return domains
def _get_automation_component_lookup_table(
hass: HomeAssistant,
component_type: str,
component_descriptions: Mapping[str, Mapping[str, Any] | None],
) -> _AutomationComponentLookupTable:
"""Get a dict of automation components keyed by domain, along with the total number of components.
Returns a cached object if available.
"""
try:
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE]
except KeyError:
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE] = {}
if (cached := cache.get(component_type)) is not None:
cached_descriptions, cached_lookup = cached
if cached_descriptions is component_descriptions:
_LOGGER.debug(
"Using cached automation component lookup data for %s", component_type
)
return cached_lookup
lookup_table = _AutomationComponentLookupTable(
domain_components={}, component_count=0
)
for component, description in component_descriptions.items():
if description is None or CONF_TARGET not in description:
_LOGGER.debug("Skipping component %s without target description", component)
continue
domains = _get_automation_component_domains(description[CONF_TARGET])
lookup_data = _AutomationComponentLookupData.create(
component, description[CONF_TARGET]
)
for domain in domains:
lookup_table.domain_components.setdefault(domain, []).append(lookup_data)
lookup_table.component_count += 1
cache[component_type] = (component_descriptions, lookup_table)
return lookup_table
def _async_get_automation_components_for_target(
hass: HomeAssistant,
component_type: str,
target_selection: ConfigType,
expand_group: bool,
component_descriptions: Mapping[str, Mapping[str, Any] | None],
@@ -155,27 +210,17 @@ def _async_get_automation_components_for_target(
)
_LOGGER.debug("Extracted entities for lookup: %s", extracted)
# Build lookup structure: domain -> list of trigger/condition/service lookup data
domain_components: dict[str | None, list[_AutomationComponentLookupData]] = {}
component_count = 0
for component, description in component_descriptions.items():
if description is None or CONF_TARGET not in description:
_LOGGER.debug("Skipping component %s without target description", component)
continue
domains = _get_automation_component_domains(description[CONF_TARGET])
lookup_data = _AutomationComponentLookupData.create(
component, description[CONF_TARGET]
lookup_table = _get_automation_component_lookup_table(
hass, component_type, component_descriptions
)
_LOGGER.debug(
"Automation components per domain: %s", lookup_table.domain_components
)
for domain in domains:
domain_components.setdefault(domain, []).append(lookup_data)
component_count += 1
_LOGGER.debug("Automation components per domain: %s", domain_components)
entity_infos = entity_sources(hass)
matched_components: set[str] = set()
for entity_id in extracted.referenced | extracted.indirectly_referenced:
if component_count == len(matched_components):
if lookup_table.component_count == len(matched_components):
# All automation components matched already, so we don't need to iterate further
break
@@ -187,7 +232,11 @@ def _async_get_automation_components_for_target(
entity_domain = entity_id.split(".")[0]
entity_integration = entity_info["domain"]
for domain in (entity_domain, entity_integration, None):
for component_data in domain_components.get(domain, []):
if not (
domain_component_data := lookup_table.domain_components.get(domain)
):
continue
for component_data in domain_component_data:
if component_data.component in matched_components:
continue
if component_data.matches(
@@ -204,7 +253,7 @@ async def async_get_triggers_for_target(
"""Get triggers for a target."""
descriptions = await async_get_all_trigger_descriptions(hass)
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, descriptions
hass, "triggers", target_selector, expand_group, descriptions
)
@@ -214,7 +263,7 @@ async def async_get_conditions_for_target(
"""Get conditions for a target."""
descriptions = await async_get_all_condition_descriptions(hass)
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, descriptions
hass, "conditions", target_selector, expand_group, descriptions
)
@@ -247,5 +296,9 @@ async def async_get_services_for_target(
return flattened_descriptions
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, get_flattened_service_descriptions()
hass,
"services",
target_selector,
expand_group,
get_flattened_service_descriptions(),
)

View File

@@ -24,6 +24,10 @@ from homeassistant.components.websocket_api.auth import (
TYPE_AUTH_OK,
TYPE_AUTH_REQUIRED,
)
from homeassistant.components.websocket_api.automation import (
AUTOMATION_COMPONENT_LOOKUP_CACHE,
_get_automation_component_lookup_table,
)
from homeassistant.components.websocket_api.commands import (
ALL_CONDITION_DESCRIPTIONS_JSON_CACHE,
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE,
@@ -3665,6 +3669,7 @@ async def test_get_triggers_conditions_for_target(
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
automation_component: str,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test get_triggers_for_target/get_conditions_for_target command with mixed target types."""
@@ -3803,7 +3808,9 @@ async def test_get_triggers_conditions_for_target(
await hass.async_block_till_done()
async def assert_command(
target: dict[str, list[str]], expected: list[str]
target: dict[str, list[str]],
expected: list[str],
expect_lookup_cache: bool = True,
) -> Any:
"""Call the command and assert expected triggers/conditions."""
await websocket_client.send_json_auto_id(
@@ -3815,8 +3822,15 @@ async def test_get_triggers_conditions_for_target(
assert msg["success"]
assert sorted(msg["result"]) == sorted(expected)
assert (
"Using cached automation component lookup data" in caplog.text
) == expect_lookup_cache
caplog.clear()
# Test entity target - unknown entity
await assert_command({"entity_id": ["light.unknown_entity"]}, [])
await assert_command(
{"entity_id": ["light.unknown_entity"]}, [], expect_lookup_cache=False
)
# Test entity target - entity not in registry
await assert_command(
@@ -3936,6 +3950,7 @@ async def test_get_services_for_target(
mock_load_yaml: Mock,
hass: HomeAssistant,
websocket_client: MockHAClientWebSocket,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test get_services_for_target command with mixed target types."""
@@ -4047,7 +4062,11 @@ async def test_get_services_for_target(
)
await hass.async_block_till_done()
async def assert_services(target: dict[str, list[str]], expected: list[str]) -> Any:
async def assert_services(
target: dict[str, list[str]],
expected: list[str],
expect_lookup_cache: bool = True,
) -> Any:
"""Call the command and assert expected services."""
await websocket_client.send_json_auto_id(
{"type": "get_services_for_target", "target": target}
@@ -4058,8 +4077,15 @@ async def test_get_services_for_target(
assert msg["success"]
assert sorted(msg["result"]) == sorted(expected)
assert (
"Using cached automation component lookup data" in caplog.text
) == expect_lookup_cache
caplog.clear()
# Test entity target - unknown entity
await assert_services({"entity_id": ["light.unknown_entity"]}, [])
await assert_services(
{"entity_id": ["light.unknown_entity"]}, [], expect_lookup_cache=False
)
# Test entity target - entity not in registry
await assert_services(
@@ -4212,7 +4238,7 @@ async def test_get_services_for_target_caching(
await call_command()
assert mock_get_components.call_count == 1
first_flat_descriptions = mock_get_components.call_args_list[0][0][3]
first_flat_descriptions = mock_get_components.call_args_list[0][0][4]
assert first_flat_descriptions == {
"light.turn_on": {
"fields": {},
@@ -4227,7 +4253,7 @@ async def test_get_services_for_target_caching(
# Second call: should reuse cached flat descriptions
await call_command()
assert mock_get_components.call_count == 2
second_flat_descriptions = mock_get_components.call_args_list[1][0][3]
second_flat_descriptions = mock_get_components.call_args_list[1][0][4]
assert first_flat_descriptions is second_flat_descriptions
# Register a new service to invalidate cache
@@ -4237,6 +4263,89 @@ async def test_get_services_for_target_caching(
# Third call: cache should be rebuilt
await call_command()
assert mock_get_components.call_count == 3
third_flat_descriptions = mock_get_components.call_args_list[2][0][3]
third_flat_descriptions = mock_get_components.call_args_list[2][0][4]
assert "new_domain.new_service" in third_flat_descriptions
assert third_flat_descriptions is not first_flat_descriptions
async def test_get_automation_component_lookup_table_cache(
hass: HomeAssistant,
) -> None:
"""Test that _get_automation_component_lookup_table caches and rotates properly."""
triggers: dict[str, dict[str, Any] | None] = {
"light.turned_on": {"target": {"entity": [{"domain": ["light"]}]}},
"switch.turned_on": {"target": {"entity": [{"domain": ["switch"]}]}},
}
conditions: dict[str, dict[str, Any] | None] = {
"light.is_on": {"target": {"entity": [{"domain": ["light"]}]}},
"sensor.is_above": {"target": {"entity": [{"domain": ["sensor"]}]}},
}
services: dict[str, dict[str, Any] | None] = {
"light.turn_on": {"target": {"entity": [{"domain": ["light"]}]}},
"climate.set_temperature": {"target": {"entity": [{"domain": ["climate"]}]}},
}
# First call with triggers - cache should be created with 1 entry
trigger_result1 = _get_automation_component_lookup_table(hass, "triggers", triggers)
assert AUTOMATION_COMPONENT_LOOKUP_CACHE in hass.data
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE]
assert len(cache) == 1
# Second call with same triggers - should return cached result
trigger_result2 = _get_automation_component_lookup_table(hass, "triggers", triggers)
assert trigger_result1 is trigger_result2
assert len(cache) == 1
# Call with conditions
condition_result1 = _get_automation_component_lookup_table(
hass, "conditions", conditions
)
assert condition_result1 is not trigger_result1
assert len(cache) == 2
# Call with services
service_result1 = _get_automation_component_lookup_table(hass, "services", services)
assert service_result1 is not trigger_result1
assert service_result1 is not condition_result1
assert len(cache) == 3
# Verify all 3 return cached results
assert (
_get_automation_component_lookup_table(hass, "triggers", triggers)
is trigger_result1
)
assert (
_get_automation_component_lookup_table(hass, "conditions", conditions)
is condition_result1
)
assert (
_get_automation_component_lookup_table(hass, "services", services)
is service_result1
)
assert len(cache) == 3
# Add a new triggers description dict - replaces previous triggers cache
new_triggers: dict[str, dict[str, Any] | None] = {
"fan.turned_on": {"target": {"entity": [{"domain": ["fan"]}]}},
}
_get_automation_component_lookup_table(hass, "triggers", new_triggers)
assert len(cache) == 3
# Initial trigger cache entry should have been replaced
trigger_result3 = _get_automation_component_lookup_table(hass, "triggers", triggers)
assert trigger_result3 is not trigger_result1
assert len(cache) == 3
# Verify all 3 return cached results again
assert (
_get_automation_component_lookup_table(hass, "triggers", triggers)
is trigger_result3
)
assert (
_get_automation_component_lookup_table(hass, "conditions", conditions)
is condition_result1
)
assert (
_get_automation_component_lookup_table(hass, "services", services)
is service_result1
)