# Mirror of https://github.com/home-assistant/core.git
# (synced 2026-06-01 13:14:35 +01:00)
# File: core/homeassistant/components/thread/dataset_store.py
# Last change: 2026-04-30 21:14:48 +02:00 (515 lines, 20 KiB, Python)
"""Persistently store thread datasets."""
from asyncio import Event, Task, wait
import dataclasses
from datetime import datetime
import logging
from pprint import pformat
from typing import Any, cast
from propcache.api import cached_property
from python_otbr_api import tlv_parser
from python_otbr_api.tlv_parser import MeshcopTLVType
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.redact import REDACTED
from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from . import discovery
_LOGGER = logging.getLogger(__name__)

# Key under which the @singleton helper keeps the DatasetStore instance.
DATA_STORE = "thread.datasets"

# Persistent storage: key and schema version (major.minor). The minor version
# must match the last step handled by DatasetStoreStore._async_migrate_func.
STORAGE_KEY = "thread.datasets"
STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 4
# Delay handed to Store.async_delay_save (presumably seconds - confirm).
SAVE_DELAY = 10

# Timeout, in seconds, passed to asyncio.wait while looking for Thread routers.
BORDER_AGENT_DISCOVERY_TIMEOUT = 30
def _format_dataset(
    dataset: dict[MeshcopTLVType | int, tlv_parser.MeshcopTLVItem],
) -> dict[str, str]:
    """Render a parsed Thread dataset as a mapping suitable for logging.

    Keys are the ``MeshcopTLVType`` member names, or ``str(key)`` for TLV types
    the enum does not know. Values are stringified, except the network key and
    the PSKc, which are replaced by ``REDACTED`` so credentials never hit logs.
    """
    secret_types = (MeshcopTLVType.NETWORKKEY, MeshcopTLVType.PSKC)
    return {
        (tlv_type.name if isinstance(tlv_type, MeshcopTLVType) else str(tlv_type)): (
            REDACTED if tlv_type in secret_types else str(item)
        )
        for tlv_type, item in dataset.items()
    }
class DatasetPreferredError(HomeAssistantError):
    """Error signalling that the preferred dataset cannot be deleted."""
@dataclasses.dataclass(frozen=True)
class DatasetEntry:
    """Dataset store entry.

    Instances are immutable; the store derives updated entries with
    ``dataclasses.replace``. The parsed form of ``tlv`` is computed on first
    access to ``dataset`` and cached.
    """

    # ID of the preferred border agent for this dataset, if any.
    preferred_border_agent_id: str | None
    # Extended address of the preferred border agent, if any.
    preferred_extended_address: str | None
    # Where the dataset came from (string supplied by the caller).
    source: str
    # Raw dataset, in the form accepted by tlv_parser.parse_tlv.
    tlv: str
    # Creation time; defaults to dt_util.utcnow().
    created: datetime = dataclasses.field(default_factory=dt_util.utcnow)
    # Entry ID; defaults to a freshly generated ULID.
    id: str = dataclasses.field(default_factory=ulid_util.ulid_now)

    @property
    def channel(self) -> int | None:
        """Return channel as an integer, or None if the dataset has no channel."""
        if (channel := self.dataset.get(MeshcopTLVType.CHANNEL)) is None:
            return None
        return cast(tlv_parser.Channel, channel).channel

    @cached_property
    def dataset(self) -> dict[MeshcopTLVType | int, tlv_parser.MeshcopTLVItem]:
        """Return the dataset in dict format (parsed from ``tlv``, cached)."""
        return tlv_parser.parse_tlv(self.tlv)

    @property
    def extended_pan_id(self) -> str:
        """Return extended PAN ID as a hex string.

        Raises KeyError if the dataset has no extended PAN ID; the store only
        keeps datasets that have one.
        """
        return str(self.dataset[MeshcopTLVType.EXTPANID])

    @property
    def network_name(self) -> str | None:
        """Return network name as a string, or None if it is missing."""
        if (name := self.dataset.get(MeshcopTLVType.NETWORKNAME)) is None:
            return None
        return cast(tlv_parser.NetworkName, name).name

    @property
    def pan_id(self) -> str | None:
        """Return PAN ID as a hex string, or None if the dataset has no PAN ID."""
        # str(None) would yield the literal string "None"; return a real None
        # like channel and network_name do.
        if (pan_id := self.dataset.get(MeshcopTLVType.PANID)) is None:
            return None
        return str(pan_id)

    def to_json(self) -> dict[str, Any]:
        """Return a JSON serializable representation for storage."""
        return {
            "created": self.created.isoformat(),
            "id": self.id,
            "preferred_border_agent_id": self.preferred_border_agent_id,
            "preferred_extended_address": self.preferred_extended_address,
            "source": self.source,
            "tlv": self.tlv,
        }
class DatasetStoreStore(Store):
    """Store Thread datasets."""

    @staticmethod
    def _deduplicate_datasets(old_data: dict[str, Any]) -> dict[str, Any]:
        """Drop invalid datasets and keep one dataset per extended PAN ID.

        Datasets without an extended PAN ID or an active timestamp are dropped;
        if such a dataset was preferred, no dataset stays preferred. Among
        datasets sharing an extended PAN ID, the preferred one always wins.
        Otherwise the one with the newer active timestamp wins, and the first
        one seen wins a tie.
        """
        datasets: dict[str, DatasetEntry] = {}
        preferred_dataset = old_data["preferred_dataset"]
        for dataset in old_data["datasets"]:
            created = cast(datetime, dt_util.parse_datetime(dataset["created"]))
            entry = DatasetEntry(
                created=created,
                id=dataset["id"],
                preferred_border_agent_id=None,
                preferred_extended_address=None,
                source=dataset["source"],
                tlv=dataset["tlv"],
            )
            if (
                MeshcopTLVType.EXTPANID not in entry.dataset
                or MeshcopTLVType.ACTIVETIMESTAMP not in entry.dataset
            ):
                _LOGGER.warning(
                    "Dropped invalid Thread dataset:\n%s",
                    pformat(_format_dataset(entry.dataset)),
                )
                if entry.id == preferred_dataset:
                    preferred_dataset = None
                continue

            if (existing := datasets.get(entry.extended_pan_id)) is not None:
                if existing.id == preferred_dataset:
                    _LOGGER.warning(
                        "Dropped duplicated Thread dataset"
                        " (duplicate of preferred dataset):\n%s\nkept:\n%s",
                        pformat(_format_dataset(entry.dataset)),
                        pformat(_format_dataset(existing.dataset)),
                    )
                    continue
                # Never drop the preferred dataset: doing so would leave
                # preferred_dataset pointing at an ID that no longer exists.
                if entry.id != preferred_dataset:
                    new_timestamp = cast(
                        tlv_parser.Timestamp,
                        entry.dataset[MeshcopTLVType.ACTIVETIMESTAMP],
                    )
                    old_timestamp = cast(
                        tlv_parser.Timestamp,
                        existing.dataset[MeshcopTLVType.ACTIVETIMESTAMP],
                    )
                    if (old_timestamp.seconds, old_timestamp.ticks) >= (
                        new_timestamp.seconds,
                        new_timestamp.ticks,
                    ):
                        _LOGGER.warning(
                            "Dropped duplicated Thread dataset:\n%s\nkept:\n%s",
                            pformat(_format_dataset(entry.dataset)),
                            pformat(_format_dataset(existing.dataset)),
                        )
                        continue
                _LOGGER.warning(
                    "Dropped duplicated Thread dataset:\n%s\nkept:\n%s",
                    pformat(_format_dataset(existing.dataset)),
                    pformat(_format_dataset(entry.dataset)),
                )
            datasets[entry.extended_pan_id] = entry

        return {
            "preferred_dataset": preferred_dataset,
            "datasets": [dataset.to_json() for dataset in datasets.values()],
        }

    async def _async_migrate_func(
        self, old_major_version: int, old_minor_version: int, old_data: dict[str, Any]
    ) -> dict[str, Any]:
        """Migrate to the new version.

        Only major version 1 exists. Minor-version steps are applied in order:
        - below 1.2: drop invalid datasets and deduplicate by extended PAN ID;
        - below 1.4: clear the preferred border agent ID and reset the
          preferred extended address on every dataset.

        Raises:
            NotImplementedError: if old_major_version is not 1.
        """
        if old_major_version != 1:
            raise NotImplementedError(
                f"Unsupported Thread dataset store major version {old_major_version}"
            )

        data = old_data
        if old_minor_version < 2:
            # Deduplicate datasets
            data = self._deduplicate_datasets(old_data)

        # Migration to version 1.3 removed, it added the ID of the preferred border
        # agent
        if old_minor_version < 4:
            # Add extended address of the preferred border agent and clear border
            # agent ID
            for dataset in data["datasets"]:
                dataset["preferred_border_agent_id"] = None
                dataset["preferred_extended_address"] = None

        return data
class DatasetStore:
    """Class to hold a collection of thread datasets."""

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the dataset store."""
        self.hass = hass
        # Datasets keyed by DatasetEntry.id.
        self.datasets: dict[str, DatasetEntry] = {}
        # ID of the preferred dataset, or None when no dataset is preferred.
        self._preferred_dataset: str | None = None
        # Background "set preferred if only network" check; async_add starts it
        # at most once per store instance.
        self._set_preferred_dataset_task: Task | None = None
        self._store: Store[dict[str, Any]] = DatasetStoreStore(
            hass,
            STORAGE_VERSION_MAJOR,
            STORAGE_KEY,
            atomic_writes=True,
            minor_version=STORAGE_VERSION_MINOR,
        )

    @staticmethod
    def _active_timestamp(
        dataset: dict[MeshcopTLVType | int, tlv_parser.MeshcopTLVItem],
    ) -> tuple[int, int]:
        """Return a dataset's active timestamp as a comparable (seconds, ticks)."""
        timestamp = cast(
            tlv_parser.Timestamp, dataset[MeshcopTLVType.ACTIVETIMESTAMP]
        )
        return (timestamp.seconds, timestamp.ticks)

    @staticmethod
    def _without_wakeup_channel(
        dataset: dict[MeshcopTLVType | int, tlv_parser.MeshcopTLVItem],
    ) -> dict[MeshcopTLVType | int, tlv_parser.MeshcopTLVItem]:
        """Return a copy of the dataset with the WAKEUP_CHANNEL TLV removed."""
        return {
            key: value
            for key, value in dataset.items()
            if key != MeshcopTLVType.WAKEUP_CHANNEL
        }

    @callback
    def async_add(
        self,
        source: str,
        tlv: str,
        preferred_border_agent_id: str | None,
        preferred_extended_address: str | None,
    ) -> None:
        """Add dataset, does nothing if it already exists.

        A stored dataset with the same extended PAN ID is replaced only when
        the new dataset has a newer active timestamp. When a brand-new dataset
        is added with a preferred extended address and no dataset is preferred
        yet, a one-time background check may mark it as preferred.

        Raises:
            HomeAssistantError: if the dataset lacks an extended PAN ID or an
                active timestamp, or if a preferred border agent ID is given
                without a preferred extended address.
        """
        # Make sure the tlv is valid
        dataset = tlv_parser.parse_tlv(tlv)

        # Don't allow adding a dataset which does not have an extended pan id or
        # timestamp
        if (
            MeshcopTLVType.EXTPANID not in dataset
            or MeshcopTLVType.ACTIVETIMESTAMP not in dataset
        ):
            raise HomeAssistantError("Invalid dataset")

        # Don't allow setting preferred border agent ID without setting
        # preferred extended address
        if preferred_border_agent_id is not None and preferred_extended_address is None:
            raise HomeAssistantError(
                "Must set preferred extended address with preferred border agent ID"
            )

        # Bail out if the dataset already exists
        entry: DatasetEntry | None
        for entry in self.datasets.values():
            if entry.dataset == dataset:
                if (
                    preferred_extended_address
                    and entry.preferred_extended_address is None
                ):
                    self.async_set_preferred_border_agent(
                        entry.id, preferred_border_agent_id, preferred_extended_address
                    )
                return

        # Update if dataset with same extended pan id exists and the timestamp
        # is newer
        if entry := next(
            (
                entry
                for entry in self.datasets.values()
                if entry.dataset[MeshcopTLVType.EXTPANID]
                == dataset[MeshcopTLVType.EXTPANID]
            ),
            None,
        ):
            old_ts = self._active_timestamp(entry.dataset)
            new_ts = self._active_timestamp(dataset)
            if old_ts >= new_ts:
                # Silently accept if the datasets only differ in WAKEUP_CHANNEL:
                # it was added in OpenThread but the wake-up protocol isn't
                # defined yet, so we treat it as if it were always present.
                # Strip it from both sides so it doesn't matter which of the
                # two datasets carries the TLV.
                if old_ts > new_ts or self._without_wakeup_channel(
                    dataset
                ) != self._without_wakeup_channel(entry.dataset):
                    _LOGGER.warning(
                        "Got dataset with same extended PAN ID and same or older"
                        " active timestamp\nold:\n%s\nnew:\n%s",
                        pformat(_format_dataset(entry.dataset)),
                        pformat(_format_dataset(dataset)),
                    )
                return

            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug(
                    "Updating dataset with same extended PAN ID and newer"
                    " active timestamp\nold:\n%s\nnew:\n%s",
                    pformat(_format_dataset(entry.dataset)),
                    pformat(_format_dataset(dataset)),
                )
            self.datasets[entry.id] = dataclasses.replace(
                self.datasets[entry.id], tlv=tlv
            )
            self.async_schedule_save()
            if preferred_extended_address and entry.preferred_extended_address is None:
                self.async_set_preferred_border_agent(
                    entry.id, preferred_border_agent_id, preferred_extended_address
                )
            return

        entry = DatasetEntry(
            preferred_border_agent_id=preferred_border_agent_id,
            preferred_extended_address=preferred_extended_address,
            source=source,
            tlv=tlv,
        )
        self.datasets[entry.id] = entry
        self.async_schedule_save()

        # Set the new network as preferred if there is no preferred dataset and there is
        # no other router present. We only attempt this once.
        if (
            self._preferred_dataset is None
            and preferred_extended_address
            and not self._set_preferred_dataset_task
        ):
            self._set_preferred_dataset_task = self.hass.async_create_task(
                self._set_preferred_dataset_if_only_network(
                    entry.id, preferred_extended_address
                )
            )

    @callback
    def async_delete(self, dataset_id: str) -> None:
        """Delete dataset.

        Raises:
            DatasetPreferredError: if dataset_id is the preferred dataset.
            KeyError: if dataset_id is unknown.
        """
        if self._preferred_dataset == dataset_id:
            raise DatasetPreferredError("attempt to remove preferred dataset")
        del self.datasets[dataset_id]
        self.async_schedule_save()

    @callback
    def async_get(self, dataset_id: str) -> DatasetEntry | None:
        """Get dataset by id, or None if unknown."""
        return self.datasets.get(dataset_id)

    @callback
    def async_set_preferred_border_agent(
        self, dataset_id: str, border_agent_id: str | None, extended_address: str
    ) -> None:
        """Set preferred border agent id and extended address of a dataset.

        Raises:
            HomeAssistantError: if border_agent_id is set but extended_address
                is None.
            KeyError: if dataset_id is unknown.
        """
        # Don't allow setting preferred border agent ID without setting
        # preferred extended address
        if border_agent_id is not None and extended_address is None:
            raise HomeAssistantError(
                "Must set preferred extended address with preferred border agent ID"
            )

        self.datasets[dataset_id] = dataclasses.replace(
            self.datasets[dataset_id],
            preferred_border_agent_id=border_agent_id,
            preferred_extended_address=extended_address,
        )
        self.async_schedule_save()

    @property
    @callback
    def preferred_dataset(self) -> str | None:
        """Get the id of the preferred dataset."""
        return self._preferred_dataset

    @preferred_dataset.setter
    @callback
    def preferred_dataset(self, dataset_id: str) -> None:
        """Set the preferred dataset; raises KeyError for an unknown id."""
        if dataset_id not in self.datasets:
            raise KeyError("unknown dataset")
        self._preferred_dataset = dataset_id
        self.async_schedule_save()

    async def _set_preferred_dataset_if_only_network(
        self, dataset_id: str, extended_address: str | None
    ) -> None:
        """Set the preferred dataset, unless there are other routers present.

        Runs Thread router discovery for up to BORDER_AGENT_DISCOVERY_TIMEOUT
        seconds. The dataset is marked preferred only if the router with
        ``extended_address`` is seen and no other router is. The dataset must
        also still exist, and no preferred dataset may have been chosen in the
        meantime.
        """
        _LOGGER.debug(
            "_set_preferred_dataset_if_only_network called for router %s",
            extended_address,
        )

        own_router_evt = Event()
        other_router_evt = Event()

        @callback
        def router_discovered(
            key: str, data: discovery.ThreadRouterDiscoveryData
        ) -> None:
            """Handle router discovered."""
            _LOGGER.debug("discovered router with ext addr %s", data.extended_address)
            if data.extended_address == extended_address:
                own_router_evt.set()
                return

            other_router_evt.set()

        # Start Thread router discovery
        thread_discovery = discovery.ThreadRouterDiscovery(
            self.hass, router_discovered, lambda key: None
        )
        await thread_discovery.async_start()

        found_own_router = self.hass.async_create_task(own_router_evt.wait())
        found_other_router = self.hass.async_create_task(other_router_evt.wait())
        pending = {found_own_router, found_other_router}
        try:
            (done, pending) = await wait(
                pending, timeout=BORDER_AGENT_DISCOVERY_TIMEOUT
            )
            if found_other_router in done:
                # We found another router on the network, don't set the dataset
                # as preferred
                _LOGGER.debug("Other router found, do not set dataset as default")

            # Note that asyncio.wait does not raise TimeoutError, it instead returns
            # the jobs which did not finish in the pending-set.
            elif found_own_router in pending:
                # Either the router is not there, or mDNS is not working. In any case,
                # don't set the router as preferred.
                _LOGGER.debug("Own router not found, do not set dataset as default")

            # The store may have changed while we were waiting: don't override a
            # preferred dataset chosen in the meantime, and don't try to prefer a
            # dataset that was deleted (the setter would raise KeyError).
            elif self._preferred_dataset is not None:
                _LOGGER.debug(
                    "Preferred dataset already set, do not set dataset as default"
                )
            elif dataset_id not in self.datasets:
                _LOGGER.debug("Dataset was removed, do not set dataset as default")

            else:
                # We've discovered the router connected to the dataset, but we did not
                # find any other router on the network - mark the dataset as preferred.
                _LOGGER.debug("No other router found, set dataset as default")
                self.preferred_dataset = dataset_id
        finally:
            # Always cancel the unfinished waiters and stop discovery, including
            # when this task is cancelled while waiting.
            for task in pending:
                task.cancel()
            await thread_discovery.async_stop()

    async def async_load(self) -> None:
        """Load the datasets."""
        data = await self._store.async_load()

        datasets: dict[str, DatasetEntry] = {}
        preferred_dataset: str | None = None

        if data is not None:
            for dataset in data["datasets"]:
                created = cast(datetime, dt_util.parse_datetime(dataset["created"]))
                datasets[dataset["id"]] = DatasetEntry(
                    created=created,
                    id=dataset["id"],
                    preferred_border_agent_id=dataset["preferred_border_agent_id"],
                    preferred_extended_address=dataset["preferred_extended_address"],
                    source=dataset["source"],
                    tlv=dataset["tlv"],
                )
            preferred_dataset = data["preferred_dataset"]

        self.datasets = datasets
        self._preferred_dataset = preferred_dataset

    @callback
    def async_schedule_save(self) -> None:
        """Schedule saving the dataset store."""
        self._store.async_delay_save(self._data_to_save, SAVE_DELAY)

    @callback
    def _data_to_save(self) -> dict[str, Any]:
        """Return data of datasets to store in a file.

        The result maps "datasets" to a list of DatasetEntry.to_json() dicts and
        "preferred_dataset" to the preferred dataset ID (or None).
        """
        return {
            "datasets": [dataset.to_json() for dataset in self.datasets.values()],
            "preferred_dataset": self._preferred_dataset,
        }
@singleton(DATA_STORE)
async def async_get_store(hass: HomeAssistant) -> DatasetStore:
    """Build a DatasetStore, load its persisted data and return it.

    Wrapped by @singleton(DATA_STORE), which presumably caches the result so
    later calls reuse the same store - confirm against helpers.singleton.
    """
    dataset_store = DatasetStore(hass)
    await dataset_store.async_load()
    return dataset_store
async def async_add_dataset(
    hass: HomeAssistant,
    source: str,
    tlv: str,
    *,
    preferred_border_agent_id: str | None = None,
    preferred_extended_address: str | None = None,
) -> None:
    """Add a dataset to the shared store; see DatasetStore.async_add for rules."""
    (await async_get_store(hass)).async_add(
        source,
        tlv,
        preferred_border_agent_id,
        preferred_extended_address,
    )
async def async_get_dataset(hass: HomeAssistant, dataset_id: str) -> str | None:
    """Return the TLV of the dataset with the given id, or None if unknown."""
    found = (await async_get_store(hass)).async_get(dataset_id)
    return None if found is None else found.tlv
async def async_get_preferred_dataset(hass: HomeAssistant) -> str | None:
    """Return the TLV of the preferred dataset.

    Returns None when no dataset is preferred or the preferred id is unknown.
    """
    store = await async_get_store(hass)
    preferred_id = store.preferred_dataset
    if preferred_id is None:
        return None
    preferred_entry = store.async_get(preferred_id)
    if preferred_entry is None:
        return None
    return preferred_entry.tlv