1
0
mirror of https://github.com/home-assistant/core.git synced 2026-04-02 00:20:30 +01:00

Introduce per-source DataUpdateCoordinator for UniFi polling data sources (#166806)

This commit is contained in:
Robert Svensson
2026-03-30 16:48:18 +02:00
committed by GitHub
parent 0a05993a4e
commit 732b170190
6 changed files with 97 additions and 44 deletions

View File

@@ -0,0 +1,45 @@
"""UniFi Network data update coordinator."""
from __future__ import annotations
from datetime import timedelta
from typing import TYPE_CHECKING
from aiounifi.interfaces.api_handlers import APIHandler
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .const import LOGGER
if TYPE_CHECKING:
from .hub.hub import UnifiHub
POLL_INTERVAL = timedelta(seconds=10)
class UnifiDataUpdateCoordinator[HandlerT: APIHandler](DataUpdateCoordinator[None]):
"""Coordinator managing polling for a single UniFi API data source."""
def __init__(
self,
hub: UnifiHub,
handler: HandlerT,
) -> None:
"""Initialize coordinator."""
super().__init__(
hub.hass,
LOGGER,
name=f"UniFi {type(handler).__name__}",
config_entry=hub.config.entry,
update_interval=POLL_INTERVAL,
)
self._handler = handler
@property
def handler(self) -> HandlerT:
"""Return the aiounifi handler managed by this coordinator."""
return self._handler
async def _async_update_data(self) -> None:
"""Update data from the API handler."""
await self._handler.update()

View File

@@ -94,16 +94,14 @@ def async_client_device_info_fn(hub: UnifiHub, obj_id: str) -> DeviceInfo:
@dataclass(frozen=True, kw_only=True)
class UnifiEntityDescription[HandlerT: APIHandler, ApiItemT: ApiItem](
EntityDescription
):
class UnifiEntityDescription[HandlerT: APIHandler, ItemT: ApiItem](EntityDescription):
"""UniFi Entity Description."""
api_handler_fn: Callable[[aiounifi.Controller], HandlerT]
"""Provide api_handler from api."""
device_info_fn: Callable[[UnifiHub, str], DeviceInfo | None]
"""Provide device info object based on hub and obj_id."""
object_fn: Callable[[aiounifi.Controller, str], ApiItemT]
object_fn: Callable[[aiounifi.Controller, str], ItemT]
"""Retrieve object based on api and obj_id."""
unique_id_fn: Callable[[UnifiHub, str], str]
"""Provide a unique ID based on hub and obj_id."""
@@ -113,7 +111,7 @@ class UnifiEntityDescription[HandlerT: APIHandler, ApiItemT: ApiItem](
"""Determine if config entry options allow creation of entity."""
available_fn: Callable[[UnifiHub, str], bool] = lambda hub, obj_id: hub.available
"""Determine if entity is available, default is if connection is working."""
name_fn: Callable[[ApiItemT], str | None] = lambda obj: None
name_fn: Callable[[ItemT], str | None] = lambda obj: None
"""Entity name function, can be used to extend entity name beyond device name."""
supported_fn: Callable[[UnifiHub, str], bool] = lambda hub, obj_id: True
"""Determine if UniFi object supports providing relevant data for entity."""
@@ -129,17 +127,17 @@ class UnifiEntityDescription[HandlerT: APIHandler, ApiItemT: ApiItem](
"""If entity needs to do regular checks on state."""
class UnifiEntity[HandlerT: APIHandler, ApiItemT: ApiItem](Entity):
class UnifiEntity[HandlerT: APIHandler, ItemT: ApiItem](Entity):
"""Representation of a UniFi entity."""
entity_description: UnifiEntityDescription[HandlerT, ApiItemT]
entity_description: UnifiEntityDescription[HandlerT, ItemT]
_attr_unique_id: str
def __init__(
self,
obj_id: str,
hub: UnifiHub,
description: UnifiEntityDescription[HandlerT, ApiItemT],
description: UnifiEntityDescription[HandlerT, ItemT],
) -> None:
"""Set up UniFi switch entity."""
self._obj_id = obj_id
@@ -258,6 +256,11 @@ class UnifiEntity[HandlerT: APIHandler, ApiItemT: ApiItem](Entity):
"""
self.async_update_state(ItemEvent.ADDED, self._obj_id)
@callback
def get_object(self) -> ItemT:
"""Return the latest object for this entity."""
return self.entity_description.object_fn(self.api, self._obj_id)
@callback
@abstractmethod
def async_update_state(self, event: ItemEvent, obj_id: str) -> None:

View File

@@ -12,30 +12,28 @@ from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any
from aiounifi.interfaces.api_handlers import ItemEvent
from aiounifi.interfaces.api_handlers import APIHandler, ItemEvent
from homeassistant.const import Platform
from homeassistant.core import callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from ..const import LOGGER, UNIFI_WIRELESS_CLIENTS
from ..coordinator import UnifiDataUpdateCoordinator
from ..entity import UnifiEntity, UnifiEntityDescription
if TYPE_CHECKING:
from .. import UnifiConfigEntry
from .hub import UnifiHub
CHECK_HEARTBEAT_INTERVAL = timedelta(seconds=1)
POLL_INTERVAL = timedelta(seconds=10)
class UnifiEntityLoader:
"""UniFi Network integration handling platforms for entity registration."""
def __init__(self, hub: UnifiHub, config_entry: UnifiConfigEntry) -> None:
def __init__(self, hub: UnifiHub) -> None:
"""Initialize the UniFi entity loader."""
self.hub = hub
self.api_updaters = (
@@ -48,28 +46,20 @@ class UnifiEntityLoader:
hub.api.sites.update,
hub.api.system_information.update,
hub.api.firewall_policies.update,
hub.api.traffic_rules.update,
hub.api.traffic_routes.update,
hub.api.wlans.update,
)
self.polling_api_updaters = (
hub.api.traffic_rules.update,
hub.api.traffic_routes.update,
)
self.wireless_clients = hub.hass.data[UNIFI_WIRELESS_CLIENTS]
self._data_update_coordinator = DataUpdateCoordinator(
hub.hass,
LOGGER,
name="Unifi entity poller",
config_entry=config_entry,
update_method=self._update_pollable_api_data,
update_interval=POLL_INTERVAL,
)
self._update_listener = self._data_update_coordinator.async_add_listener(
update_callback=lambda: None
)
self._polling_coordinators: dict[int, UnifiDataUpdateCoordinator] = {
id(hub.api.traffic_rules): UnifiDataUpdateCoordinator(
hub, hub.api.traffic_rules
),
id(hub.api.traffic_routes): UnifiDataUpdateCoordinator(
hub, hub.api.traffic_routes
),
}
for coordinator in self._polling_coordinators.values():
coordinator.async_add_listener(lambda: None)
self.platforms: list[
tuple[
@@ -85,7 +75,15 @@ class UnifiEntityLoader:
async def initialize(self) -> None:
"""Initialize API data and extra client support."""
await self._refresh_api_data()
await asyncio.gather(
self._refresh_api_data(),
self._refresh_data(
[
coordinator.async_refresh
for coordinator in self._polling_coordinators.values()
]
),
)
self._restore_inactive_clients()
self.wireless_clients.update_clients(set(self.hub.api.clients.values()))
@@ -100,10 +98,6 @@ class UnifiEntityLoader:
if result is not None:
LOGGER.warning("Exception on update %s", result)
async def _update_pollable_api_data(self) -> None:
"""Refresh API data for pollable updaters."""
await self._refresh_data(self.polling_api_updaters)
async def _refresh_api_data(self) -> None:
"""Refresh API data from network application."""
await self._refresh_data(self.api_updaters)
@@ -165,6 +159,13 @@ class UnifiEntityLoader:
and description.supported_fn(self.hub, obj_id)
)
@callback
def get_data_update_coordinator(
self, handler: APIHandler
) -> UnifiDataUpdateCoordinator | None:
"""Return the polling coordinator for a handler, if available."""
return self._polling_coordinators.get(id(handler))
@callback
def _load_entities(
self,

View File

@@ -39,7 +39,7 @@ class UnifiHub:
self.hass = hass
self.api = api
self.config = UnifiConfig.from_config_entry(config_entry)
self.entity_loader = UnifiEntityLoader(self, config_entry)
self.entity_loader = UnifiEntityLoader(self)
self._entity_helper = UnifiEntityHelper(hass, api)
self.websocket = UnifiWebsocket(hass, api, self.signal_reachable)

View File

@@ -208,8 +208,6 @@ async def async_traffic_rule_control_fn(
"""Control traffic rule state."""
traffic_rule = hub.api.traffic_rules[obj_id].raw
await hub.api.request(TrafficRuleEnableRequest.create(traffic_rule, target))
# Update the traffic rules so the UI is updated appropriately
await hub.api.traffic_rules.update()
async def async_traffic_route_control_fn(
@@ -218,8 +216,6 @@ async def async_traffic_route_control_fn(
"""Control traffic route state."""
traffic_route = hub.api.traffic_routes[obj_id].raw
await hub.api.request(TrafficRouteSaveRequest.create(traffic_route, target))
# Update the traffic routes so the UI is updated appropriately
await hub.api.traffic_routes.update()
async def async_wlan_control_fn(hub: UnifiHub, obj_id: str, target: bool) -> None:
@@ -447,10 +443,18 @@ class UnifiSwitchEntity[HandlerT: APIHandler, ApiItemT: ApiItem](
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on switch."""
await self.entity_description.control_fn(self.hub, self._obj_id, True)
if coordinator := self.hub.entity_loader.get_data_update_coordinator(
self.entity_description.api_handler_fn(self.api)
):
await coordinator.async_request_refresh()
async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off switch."""
await self.entity_description.control_fn(self.hub, self._obj_id, False)
if coordinator := self.hub.entity_loader.get_data_update_coordinator(
self.entity_description.api_handler_fn(self.api)
):
await coordinator.async_request_refresh()
@callback
def async_update_state(
@@ -464,7 +468,7 @@ class UnifiSwitchEntity[HandlerT: APIHandler, ApiItemT: ApiItem](
return
description = self.entity_description
obj = description.object_fn(self.api, self._obj_id)
obj = self.get_object()
if (is_on := description.is_on_fn(self.hub, obj)) != self.is_on:
self._attr_is_on = is_on

View File

@@ -1223,7 +1223,7 @@ async def test_traffic_rules(
expected_enable_call = deepcopy(traffic_rule)
expected_enable_call["enabled"] = True
assert aioclient_mock.call_count == call_count + 2
assert aioclient_mock.call_count == call_count + 1
assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call
@@ -1277,7 +1277,7 @@ async def test_traffic_routes(
expected_enable_call = deepcopy(traffic_route)
expected_enable_call["enabled"] = True
assert aioclient_mock.call_count == call_count + 2
assert aioclient_mock.call_count == call_count + 1
assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call