diff --git a/homeassistant/components/sunricher_dali/__init__.py b/homeassistant/components/sunricher_dali/__init__.py index 47d4317ce97..2137480cea8 100644 --- a/homeassistant/components/sunricher_dali/__init__.py +++ b/homeassistant/components/sunricher_dali/__init__.py @@ -3,9 +3,10 @@ from __future__ import annotations import asyncio +from collections.abc import Sequence import logging -from PySrDaliGateway import DaliGateway +from PySrDaliGateway import DaliGateway, Device from PySrDaliGateway.exceptions import DaliGatewayError from homeassistant.const import ( @@ -28,6 +29,38 @@ _PLATFORMS: list[Platform] = [Platform.LIGHT, Platform.SCENE] _LOGGER = logging.getLogger(__name__) +def _remove_missing_devices( + hass: HomeAssistant, + entry: DaliCenterConfigEntry, + devices: Sequence[Device], + gateway_identifier: tuple[str, str], +) -> None: + """Detach devices that are no longer provided by the gateway.""" + device_registry = dr.async_get(hass) + known_device_ids = {device.dev_id for device in devices} + + for device_entry in dr.async_entries_for_config_entry( + device_registry, entry.entry_id + ): + if gateway_identifier in device_entry.identifiers: + continue + + domain_device_ids = { + identifier[1] + for identifier in device_entry.identifiers + if identifier[0] == DOMAIN + } + + if not domain_device_ids: + continue + + if domain_device_ids.isdisjoint(known_device_ids): + device_registry.async_update_device( + device_entry.id, + remove_config_entry_id=entry.entry_id, + ) + + async def async_setup_entry(hass: HomeAssistant, entry: DaliCenterConfigEntry) -> bool: """Set up Sunricher DALI from a config entry.""" @@ -70,6 +103,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: DaliCenterConfigEntry) - model="SR-GW-EDA", serial_number=gw_sn, ) + _remove_missing_devices(hass, entry, devices, (DOMAIN, gw_sn)) entry.runtime_data = DaliCenterData( gateway=gateway, diff --git a/tests/components/sunricher_dali/test_init.py b/tests/components/sunricher_dali/test_init.py index 1941b6313d8..2dd7db20996 100644 --- a/tests/components/sunricher_dali/test_init.py +++ b/tests/components/sunricher_dali/test_init.py @@ -5,9 +5,10 @@ from unittest.mock import MagicMock from PySrDaliGateway.exceptions import DaliGatewayError from syrupy.assertion import SnapshotAssertion +from homeassistant.components.sunricher_dali.const import DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant -import homeassistant.helpers.device_registry as dr +from homeassistant.helpers import device_registry as dr from tests.common import MockConfigEntry @@ -63,6 +64,23 @@ async def test_setup_entry_connection_error( mock_gateway.connect.assert_called_once() +async def test_setup_entry_discovery_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_gateway: MagicMock, +) -> None: + """Test setup fails when device discovery fails.""" + mock_config_entry.add_to_hass(hass) + mock_gateway.discover_devices.side_effect = DaliGatewayError("Discovery failed") + + assert not await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY + mock_gateway.connect.assert_called_once() + mock_gateway.discover_devices.assert_called_once() + + async def test_unload_entry( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -80,3 +98,40 @@ async def test_unload_entry( await hass.async_block_till_done() assert mock_config_entry.state is ConfigEntryState.NOT_LOADED + + +async def test_remove_stale_devices( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_gateway: MagicMock, + mock_devices: list[MagicMock], + device_registry: dr.DeviceRegistry, +) -> None: + """Test stale devices are removed when device list decreases.""" + mock_config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + devices_before = dr.async_entries_for_config_entry( + device_registry, mock_config_entry.entry_id + ) + initial_count = len(devices_before) + + assert await hass.config_entries.async_unload(mock_config_entry.entry_id) + await hass.async_block_till_done() + + mock_gateway.discover_devices.return_value = mock_devices[:2] + + assert await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + devices_after = dr.async_entries_for_config_entry( + device_registry, mock_config_entry.entry_id + ) + assert len(devices_after) < initial_count + + gateway_device = device_registry.async_get_device( + identifiers={(DOMAIN, mock_gateway.gw_sn)} + ) + assert gateway_device is not None + assert mock_config_entry.entry_id in gateway_device.config_entries