1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Improve performance of extracting entities by label (#114720)

This commit is contained in:
J. Nick Koston
2024-04-03 10:24:44 -10:00
committed by GitHub
parent 3d8a110908
commit e86fec310b
3 changed files with 46 additions and 27 deletions

View File

@@ -512,11 +512,13 @@ class EntityRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
"""Container for entity registry items, maps entity_id -> entry.
Maintains four additional indexes:
Maintains six additional indexes:
- id -> entry
- (domain, platform, unique_id) -> entity_id
- config_entry_id -> list[key]
- device_id -> list[key]
- config_entry_id -> dict[key, True]
- device_id -> dict[key, True]
- area_id -> dict[key, True]
- label -> dict[key, True]
"""
def __init__(self) -> None:
@@ -527,6 +529,7 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._config_entry_id_index: dict[str, dict[str, Literal[True]]] = {}
self._device_id_index: dict[str, dict[str, Literal[True]]] = {}
self._area_id_index: dict[str, dict[str, Literal[True]]] = {}
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
def _index_entry(self, key: str, entry: RegistryEntry) -> None:
"""Index an entry."""
@@ -540,6 +543,8 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._device_id_index.setdefault(device_id, {})[key] = True
if (area_id := entry.area_id) is not None:
self._area_id_index.setdefault(area_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
def _unindex_entry(
self, key: str, replacement_entry: RegistryEntry | None = None
@@ -554,6 +559,9 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
self._unindex_entry_value(key, device_id, self._device_id_index)
if area_id := entry.area_id:
self._unindex_entry_value(key, area_id, self._area_id_index)
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
def get_device_ids(self) -> KeysView[str]:
"""Return device ids."""
@@ -592,6 +600,11 @@ class EntityRegistryItems(BaseRegistryItems[RegistryEntry]):
data = self.data
return [data[key] for key in self._area_id_index.get(area_id, ())]
def get_entries_for_label(self, label: str) -> list[RegistryEntry]:
"""Get entries for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]
class EntityRegistry(BaseRegistry):
"""Class to hold a registry of entities."""
@@ -1317,7 +1330,7 @@ def async_entries_for_label(
registry: EntityRegistry, label_id: str
) -> list[RegistryEntry]:
"""Return entries that match a label."""
return [entry for entry in registry.entities.values() if label_id in entry.labels]
return registry.entities.get_entries_for_label(label_id)
@callback