mirror of
https://github.com/home-assistant/core.git
synced 2026-04-17 23:53:49 +01:00
222 lines
7.1 KiB
Python
222 lines
7.1 KiB
Python
"""Helper for groups."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from propcache.api import cached_property
|
|
|
|
from homeassistant.const import ATTR_ENTITY_ID, ENTITY_MATCH_ALL, ENTITY_MATCH_NONE
|
|
from homeassistant.core import Event, HomeAssistant, callback
|
|
|
|
from . import entity_registry as er
|
|
from .singleton import singleton
|
|
|
|
if TYPE_CHECKING:
|
|
from .entity import Entity
|
|
|
|
# Key handed to the singleton decorator of get_group_entities; the mapping of
# group entities is stored/retrieved under this key.
DATA_GROUP_ENTITIES: str = "group_entities"
# Entity ID prefix of old-style groups; expand_entity_ids resolves such IDs
# through the ATTR_ENTITY_ID attribute of their state (see get_entity_ids).
ENTITY_PREFIX: str = "group."
|
|
|
|
|
|
class Group:
    """Base class for the group behaviour attached to an entity.

    Subclasses decide how member entity IDs are resolved. This base class
    only keeps the mapping returned by get_group_entities in sync with the
    owning entity being added to / removed from hass.
    """

    _entity: Entity

    def __init__(self, entity: Entity) -> None:
        """Remember the entity that owns this group helper."""
        self._entity = entity

    @property
    def member_entity_ids(self) -> list[str]:
        """Return the entity IDs of the group members.

        Subclasses must override this.
        """
        raise NotImplementedError

    @callback
    def async_added_to_hass(self) -> None:
        """Record the owning entity in the group entities mapping."""
        owner = self._entity
        group_entities = get_group_entities(owner.hass)
        group_entities[owner.entity_id] = owner

    @callback
    def async_will_remove_from_hass(self) -> None:
        """Drop the owning entity from the group entities mapping.

        Raises KeyError if the entity was never recorded.
        """
        owner = self._entity
        get_group_entities(owner.hass).pop(owner.entity_id)
|
|
|
|
|
|
class GenericGroup(Group):
    """Group whose members are given directly as entity IDs.

    The members are not restricted to a single integration.
    """

    def __init__(self, entity: Entity, member_entity_ids: list[str]) -> None:
        """Create the group for *entity* with an explicit member list."""
        super().__init__(entity)
        self._member_entity_ids = member_entity_ids

    @cached_property
    def member_entity_ids(self) -> list[str]:
        """Return the member entity IDs supplied at construction time."""
        members = self._member_entity_ids
        return members
|
|
|
|
|
|
class IntegrationSpecificGroup(Group):
    """Integration-specific entity group.

    Members come from a single integration and are referenced by unique ID.
    Entity IDs are resolved via the entity registry. This group listens for
    entity registry events to keep the resolved entity IDs up to date.
    """

    # Last resolved member entity IDs. It is only assigned by the
    # member_entity_ids cached_property, so a non-None value means that
    # cached value is populated; None means "not resolved yet / invalidated".
    _member_entity_ids: list[str] | None = None
    _member_unique_ids: list[str]

    def __init__(self, entity: Entity, member_unique_ids: list[str]) -> None:
        """Initialize the group."""
        super().__init__(entity)
        self._member_unique_ids = member_unique_ids

    @cached_property
    def member_entity_ids(self) -> list[str]:
        """Return the member entity IDs resolved from the entity registry.

        Unique IDs that have no registry entry for this entity's platform
        domain and platform name are skipped.
        """
        entity_registry = er.async_get(self._entity.hass)
        platform = self._entity.platform
        self._member_entity_ids = [
            entity_id
            for unique_id in self.member_unique_ids
            if (
                entity_id := entity_registry.async_get_entity_id(
                    platform.domain, platform.platform_name, unique_id
                )
            )
            is not None
        ]
        return self._member_entity_ids

    @property
    def member_unique_ids(self) -> list[str]:
        """Return the list of member unique IDs."""
        return self._member_unique_ids

    @member_unique_ids.setter
    def member_unique_ids(self, value: list[str]) -> None:
        """Set the member unique IDs and drop the resolved entity IDs.

        The entity state is not written here.
        """
        self._member_unique_ids = value
        self._invalidate_member_entity_ids()

    def _invalidate_member_entity_ids(self) -> None:
        """Forget the resolved member entity IDs so they are resolved again."""
        # Only delete the cached_property value when it is known to be
        # populated (see the _member_entity_ids comment above).
        if self._member_entity_ids is not None:
            self._member_entity_ids = None
            del self.member_entity_ids

    @callback
    def async_added_to_hass(self) -> None:
        """Register the entity and track registry changes affecting members."""
        super().async_added_to_hass()

        entity = self._entity
        entity_registry = er.async_get(entity.hass)

        def _is_member_entry(entity_id: str) -> bool:
            """Return True if the registry entry of entity_id is a member."""
            return bool(
                (entry := entity_registry.async_get(entity_id))
                and entry.domain == entity.platform.domain
                and entry.platform == entity.platform.platform_name
                and entry.unique_id in self.member_unique_ids
            )

        def _is_resolved_member(*entity_ids: str | None) -> bool:
            """Return True if any given ID is among the resolved member IDs."""
            resolved = self._member_entity_ids
            return resolved is not None and any(
                entity_id in resolved for entity_id in entity_ids
            )

        @callback
        def _handle_entity_registry_updated(event: Event[Any]) -> None:
            """Handle registry create, update or remove events."""
            action = event.data["action"]
            changed_entity_id = event.data["entity_id"]
            if action == "remove":
                affected = _is_resolved_member(changed_entity_id)
            elif action == "create":
                affected = _is_member_entry(changed_entity_id)
            elif action == "update":
                # Besides entries that (now) match a member unique ID, also
                # react when an entity that was resolved as a member no longer
                # matches (e.g. its unique ID changed); check both its current
                # and its previous entity ID so the cached list is not stale.
                affected = _is_member_entry(changed_entity_id) or _is_resolved_member(
                    changed_entity_id, event.data.get("old_entity_id")
                )
            else:
                affected = False

            if affected:
                self._invalidate_member_entity_ids()
                entity.async_write_ha_state()

        entity.async_on_remove(
            entity.hass.bus.async_listen(
                er.EVENT_ENTITY_REGISTRY_UPDATED,
                _handle_entity_registry_updated,
            )
        )
|
|
|
|
|
|
@callback
@singleton(DATA_GROUP_ENTITIES)
def get_group_entities(hass: HomeAssistant) -> dict[str, Entity]:
    """Return the mapping of entity ID to group entity.

    The function body only builds an empty dict; sharing it per hass is
    presumably handled by the singleton decorator (keyed on
    DATA_GROUP_ENTITIES). Entries are added by Group.async_added_to_hass and
    removed by Group.async_will_remove_from_hass.
    """
    group_entities: dict[str, Entity] = {}
    return group_entities
|
|
|
|
|
|
def expand_entity_ids(hass: HomeAssistant, entity_ids: Iterable[Any]) -> list[str]:
    """Return entity_ids with group entity ids replaced by their members.

    Groups (GenericGroup entities and old-style ``group.`` entities) are
    expanded recursively. Non-string items and the ENTITY_MATCH_NONE /
    ENTITY_MATCH_ALL markers are skipped, IDs are lowercased, and the result
    keeps first-seen order without duplicates. A group is never re-entered
    while it is already being expanded, so self-referencing or cyclic groups
    (A -> B -> A) terminate instead of recursing forever.

    Async friendly.
    """
    found_ids: list[str] = []
    _expand_entity_ids_into(
        hass, get_group_entities(hass), entity_ids, found_ids, set(), set()
    )
    return found_ids


def _expand_entity_ids_into(
    hass: HomeAssistant,
    group_entities: dict[str, Entity],
    entity_ids: Iterable[Any],
    found_ids: list[str],
    found_set: set[str],
    expanding: set[str],
) -> None:
    """Append the expansion of entity_ids to found_ids.

    found_set mirrors found_ids for O(1) duplicate checks; expanding holds the
    group IDs on the current recursion path and is used to break cycles.
    """
    for entity_id in entity_ids:
        if not isinstance(entity_id, str) or entity_id in (
            ENTITY_MATCH_NONE,
            ENTITY_MATCH_ALL,
        ):
            continue

        entity_id = entity_id.lower()

        # If entity_id points at a group, expand it
        if (entity := group_entities.get(entity_id)) is not None and isinstance(
            entity.group, GenericGroup
        ):
            child_entities: Iterable[Any] = entity.group.member_entity_ids
        # If entity_id points at an old-style group, expand it
        elif entity_id.startswith(ENTITY_PREFIX):
            child_entities = get_entity_ids(hass, entity_id)
        else:
            if entity_id not in found_set:
                found_set.add(entity_id)
                found_ids.append(entity_id)
            continue

        # A group already on the recursion path is skipped: this covers a
        # group listing itself (in any letter case) as well as longer cycles.
        if entity_id in expanding:
            continue
        expanding.add(entity_id)
        _expand_entity_ids_into(
            hass, group_entities, child_entities, found_ids, found_set, expanding
        )
        expanding.remove(entity_id)
|
|
|
|
|
|
def get_entity_ids(
    hass: HomeAssistant, entity_id: str, domain_filter: str | None = None
) -> list[str]:
    """Get members of this group.

    Reads the ATTR_ENTITY_ID attribute of the state of entity_id and returns
    an empty list when the state does not exist or the attribute is missing
    or empty. When domain_filter is given, only members whose ID starts with
    ``"<domain_filter lowercased>."`` are returned.

    The result is always a new list, so callers may modify it without
    touching the state's attributes.

    Async friendly.
    """
    group = hass.states.get(entity_id)
    if not group or not (entity_ids := group.attributes.get(ATTR_ENTITY_ID)):
        return []
    # Treat a bare string as a single entity ID instead of iterating its
    # characters.
    if isinstance(entity_ids, str):
        entity_ids = [entity_ids]
    if not domain_filter:
        return list(entity_ids)
    domain_prefix = f"{domain_filter.lower()}."
    return [ent_id for ent_id in entity_ids if ent_id.startswith(domain_prefix)]
|