diff --git a/homeassistant/components/unifi/coordinator.py b/homeassistant/components/unifi/coordinator.py new file mode 100644 index 00000000000..9b840d77132 --- /dev/null +++ b/homeassistant/components/unifi/coordinator.py @@ -0,0 +1,45 @@ +"""UniFi Network data update coordinator.""" + +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING + +from aiounifi.interfaces.api_handlers import APIHandler + +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator + +from .const import LOGGER + +if TYPE_CHECKING: + from .hub.hub import UnifiHub + +POLL_INTERVAL = timedelta(seconds=10) + + +class UnifiDataUpdateCoordinator[HandlerT: APIHandler](DataUpdateCoordinator[None]): + """Coordinator managing polling for a single UniFi API data source.""" + + def __init__( + self, + hub: UnifiHub, + handler: HandlerT, + ) -> None: + """Initialize coordinator.""" + super().__init__( + hub.hass, + LOGGER, + name=f"UniFi {type(handler).__name__}", + config_entry=hub.config.entry, + update_interval=POLL_INTERVAL, + ) + self._handler = handler + + @property + def handler(self) -> HandlerT: + """Return the aiounifi handler managed by this coordinator.""" + return self._handler + + async def _async_update_data(self) -> None: + """Update data from the API handler.""" + await self._handler.update() diff --git a/homeassistant/components/unifi/entity.py b/homeassistant/components/unifi/entity.py index 4b68287ce10..03fae17f689 100644 --- a/homeassistant/components/unifi/entity.py +++ b/homeassistant/components/unifi/entity.py @@ -94,16 +94,14 @@ def async_client_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo: @dataclass(frozen=True, kw_only=True) -class UnifiEntityDescription[HandlerT: APIHandler, ApiItemT: ApiItem]( - EntityDescription -): +class UnifiEntityDescription[HandlerT: APIHandler, ItemT: ApiItem](EntityDescription): """UniFi Entity Description.""" api_handler_fn: Callable[[aiounifi.Controller], HandlerT] """Provide api_handler from api.""" device_info_fn: Callable[[UnifiHub, str], DeviceInfo | None] """Provide device info object based on hub and obj_id.""" - object_fn: Callable[[aiounifi.Controller, str], ApiItemT] + object_fn: Callable[[aiounifi.Controller, str], ItemT] """Retrieve object based on api and obj_id.""" unique_id_fn: Callable[[UnifiHub, str], str] """Provide a unique ID based on hub and obj_id.""" @@ -113,7 +111,7 @@ class UnifiEntityDescription[HandlerT: APIHandler, ApiItemT: ApiItem]( """Determine if config entry options allow creation of entity.""" available_fn: Callable[[UnifiHub, str], bool] = lambda hub, obj_id: hub.available """Determine if entity is available, default is if connection is working.""" - name_fn: Callable[[ApiItemT], str | None] = lambda obj: None + name_fn: Callable[[ItemT], str | None] = lambda obj: None """Entity name function, can be used to extend entity name beyond device name.""" supported_fn: Callable[[UnifiHub, str], bool] = lambda hub, obj_id: True """Determine if UniFi object supports providing relevant data for entity.""" @@ -129,17 +127,17 @@ class UnifiEntityDescription[HandlerT: APIHandler, ApiItemT: ApiItem]( """If entity needs to do regular checks on state.""" -class UnifiEntity[HandlerT: APIHandler, ApiItemT: ApiItem](Entity): +class UnifiEntity[HandlerT: APIHandler, ItemT: ApiItem](Entity): """Representation of a UniFi entity.""" - entity_description: UnifiEntityDescription[HandlerT, ApiItemT] + entity_description: UnifiEntityDescription[HandlerT, ItemT] _attr_unique_id: str def __init__( self, obj_id: str, hub: UnifiHub, - description: UnifiEntityDescription[HandlerT, ApiItemT], + description: UnifiEntityDescription[HandlerT, ItemT], ) -> None: """Set up UniFi switch entity.""" self._obj_id = obj_id @@ -258,6 +256,11 @@ class UnifiEntity[HandlerT: APIHandler, ApiItemT: ApiItem](Entity): """ self.async_update_state(ItemEvent.ADDED, self._obj_id) + @callback + def get_object(self) -> ItemT: + """Return the latest object for this entity.""" + return self.entity_description.object_fn(self.api, self._obj_id) + @callback @abstractmethod def async_update_state(self, event: ItemEvent, obj_id: str) -> None: diff --git a/homeassistant/components/unifi/hub/entity_loader.py b/homeassistant/components/unifi/hub/entity_loader.py index 4fd3d34a51d..3400e707ba2 100644 --- a/homeassistant/components/unifi/hub/entity_loader.py +++ b/homeassistant/components/unifi/hub/entity_loader.py @@ -12,30 +12,28 @@ from datetime import timedelta from functools import partial from typing import TYPE_CHECKING, Any -from aiounifi.interfaces.api_handlers import ItemEvent +from aiounifi.interfaces.api_handlers import APIHandler, ItemEvent from homeassistant.const import Platform from homeassistant.core import callback from homeassistant.helpers import entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS +from ..coordinator import UnifiDataUpdateCoordinator from ..entity import UnifiEntity, UnifiEntityDescription if TYPE_CHECKING: - from .. import UnifiConfigEntry from .hub import UnifiHub CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1) -POLL_INTERVAL = timedelta(seconds=10) class UnifiEntityLoader: """UniFi Network integration handling platforms for entity registration.""" - def __init__(self, hub: UnifiHub, config_entry: UnifiConfigEntry) -> None: + def __init__(self, hub: UnifiHub) -> None: """Initialize the UniFi entity loader.""" self.hub = hub self.api_updaters = ( @@ -48,28 +46,20 @@ class UnifiEntityLoader: hub.api.sites.update, hub.api.system_information.update, hub.api.firewall_policies.update, - hub.api.traffic_rules.update, - hub.api.traffic_routes.update, hub.api.wlans.update, ) - self.polling_api_updaters = ( - hub.api.traffic_rules.update, - hub.api.traffic_routes.update, - ) self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS] - self._data_update_coordinator = DataUpdateCoordinator( - hub.hass, - LOGGER, - name="Unifi entity poller", - config_entry=config_entry, - update_method=self._update_pollable_api_data, - update_interval=POLL_INTERVAL, - ) - - self._update_listener = self._data_update_coordinator.async_add_listener( - update_callback=lambda: None - ) + self._polling_coordinators: dict[int, UnifiDataUpdateCoordinator] = { + id(hub.api.traffic_rules): UnifiDataUpdateCoordinator( + hub, hub.api.traffic_rules + ), + id(hub.api.traffic_routes): UnifiDataUpdateCoordinator( + hub, hub.api.traffic_routes + ), + } + for coordinator in self._polling_coordinators.values(): + coordinator.async_add_listener(lambda: None) self.platforms: list[ tuple[ @@ -85,7 +75,15 @@ class UnifiEntityLoader: async def initialize(self) -> None: """Initialize API data and extra client support.""" - await self._refresh_api_data() + await asyncio.gather( + self._refresh_api_data(), + self._refresh_data( + [ + coordinator.async_refresh + for coordinator in self._polling_coordinators.values() + ] + ), + ) self._restore_inactive_clients() self.wireless_clients.update_clients(set(self.hub.api.clients.values())) @@ -100,10 +98,6 @@ class UnifiEntityLoader: if result is not None: LOGGER.warning("Exception on update %s", result) - async def _update_pollable_api_data(self) -> None: - """Refresh API data for pollable updaters.""" - await self._refresh_data(self.polling_api_updaters) - async def _refresh_api_data(self) -> None: """Refresh API data from network application.""" await self._refresh_data(self.api_updaters) @@ -165,6 +159,13 @@ class UnifiEntityLoader: and description.supported_fn(self.hub, obj_id) ) + @callback + def get_data_update_coordinator( + self, handler: APIHandler + ) -> UnifiDataUpdateCoordinator | None: + """Return the polling coordinator for a handler, if available.""" + return self._polling_coordinators.get(id(handler)) + @callback def _load_entities( self, diff --git a/homeassistant/components/unifi/hub/hub.py b/homeassistant/components/unifi/hub/hub.py index 9ea887bdb29..6cf8825a26c 100644 --- a/homeassistant/components/unifi/hub/hub.py +++ b/homeassistant/components/unifi/hub/hub.py @@ -39,7 +39,7 @@ class UnifiHub: self.hass = hass self.api = api self.config = UnifiConfig.from_config_entry(config_entry) - self.entity_loader = UnifiEntityLoader(self, config_entry) + self.entity_loader = UnifiEntityLoader(self) self._entity_helper = UnifiEntityHelper(hass, api) self.websocket = UnifiWebsocket(hass, api, self.signal_reachable) diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index b9fbf48cf49..b39020204a5 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -208,8 +208,6 @@ async def async_traffic_rule_control_fn( """Control traffic rule state.""" traffic_rule = hub.api.traffic_rules[obj_id].raw await hub.api.request(TrafficRuleEnableRequest.create(traffic_rule, target)) - # Update the traffic rules so the UI is updated appropriately - await hub.api.traffic_rules.update() async def async_traffic_route_control_fn( @@ -218,8 +216,6 @@ async def async_traffic_route_control_fn( """Control traffic route state.""" traffic_route = hub.api.traffic_routes[obj_id].raw await hub.api.request(TrafficRouteSaveRequest.create(traffic_route, target)) - # Update the traffic routes so the UI is updated appropriately - await hub.api.traffic_routes.update() async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None: @@ -447,10 +443,18 @@ class UnifiSwitchEntity[HandlerT: APIHandler, ApiItemT: ApiItem]( async def async_turn_on(self, **kwargs: Any) -> None: """Turn on switch.""" await self.entity_description.control_fn(self.hub, self._obj_id, True) + if coordinator := self.hub.entity_loader.get_data_update_coordinator( + self.entity_description.api_handler_fn(self.api) + ): + await coordinator.async_request_refresh() async def async_turn_off(self, **kwargs: Any) -> None: """Turn off switch.""" await self.entity_description.control_fn(self.hub, self._obj_id, False) + if coordinator := self.hub.entity_loader.get_data_update_coordinator( + self.entity_description.api_handler_fn(self.api) + ): + await coordinator.async_request_refresh() @callback def async_update_state( @@ -464,7 +468,7 @@ class UnifiSwitchEntity[HandlerT: APIHandler, ApiItemT: ApiItem]( return description = self.entity_description - obj = description.object_fn(self.api, self._obj_id) + obj = self.get_object() if (is_on := description.is_on_fn(self.hub, obj)) != self.is_on: self._attr_is_on = is_on diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index d95a87d61f9..33b23d421f3 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -1223,7 +1223,7 @@ async def test_traffic_rules( expected_enable_call = deepcopy(traffic_rule) expected_enable_call["enabled"] = True - assert aioclient_mock.call_count == call_count + 2 + assert aioclient_mock.call_count == call_count + 1 assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call @@ -1277,7 +1277,7 @@ async def test_traffic_routes( expected_enable_call = deepcopy(traffic_route) expected_enable_call["enabled"] = True - assert aioclient_mock.call_count == call_count + 2 + assert aioclient_mock.call_count == call_count + 1 assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call