1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Protect internal coordinator state (#153685)

This commit is contained in:
Joakim Plate
2025-10-14 14:14:37 +02:00
committed by GitHub
parent 21f24c2f6a
commit d140eb4c76
9 changed files with 116 additions and 24 deletions

View File

@@ -31,7 +31,7 @@ async def async_setup_entry(
for location_id, location in coordinator.data["locations"].items()
]
async_add_entities(alarms, True)
async_add_entities(alarms)
class CanaryAlarm(

View File

@@ -68,8 +68,7 @@ async def async_setup_entry(
for location_id, location in coordinator.data["locations"].items()
for device in location.devices
if device.is_online
),
True,
)
)

View File

@@ -80,7 +80,7 @@ async def async_setup_entry(
if device_type.get("name") in sensor_type[4]
)
async_add_entities(sensors, True)
async_add_entities(sensors)
class CanarySensor(CoordinatorEntity[CanaryDataUpdateCoordinator], SensorEntity):

View File

@@ -3,8 +3,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from logging import Logger
from typing import Any
from homeassistant.core import HassJob, HomeAssistant, callback
@@ -36,6 +38,7 @@ class Debouncer[_R_co]:
self._timer_task: asyncio.TimerHandle | None = None
self._execute_at_end_of_timer: bool = False
self._execute_lock = asyncio.Lock()
self._execute_lock_owner: asyncio.Task[Any] | None = None
self._background = background
self._job: HassJob[[], _R_co] | None = (
None
@@ -46,6 +49,22 @@ class Debouncer[_R_co]:
)
self._shutdown_requested = False
@asynccontextmanager
async def async_lock(self) -> AsyncGenerator[None]:
"""Return an async context manager to lock the debouncer."""
if self._execute_lock_owner is asyncio.current_task():
raise RuntimeError("Debouncer lock is not re-entrant")
if self._execute_lock.locked():
self.logger.debug("Debouncer lock is already acquired, waiting")
async with self._execute_lock:
self._execute_lock_owner = asyncio.current_task()
try:
yield
finally:
self._execute_lock_owner = None
@property
def function(self) -> Callable[[], _R_co] | None:
"""Return the function being wrapped by the Debouncer."""
@@ -98,7 +117,7 @@ class Debouncer[_R_co]:
if not self._async_schedule_or_call_now():
return
async with self._execute_lock:
async with self.async_lock():
# Abort if timer got set while we're waiting for the lock.
if self._timer_task:
return
@@ -122,7 +141,7 @@ class Debouncer[_R_co]:
if self._execute_lock.locked():
return
async with self._execute_lock:
async with self.async_lock():
# Abort if timer got set while we're waiting for the lock.
if self._timer_task:
return

View File

@@ -128,10 +128,10 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
logger,
cooldown=REQUEST_REFRESH_DEFAULT_COOLDOWN,
immediate=REQUEST_REFRESH_DEFAULT_IMMEDIATE,
function=self.async_refresh,
function=self._async_refresh,
)
else:
request_refresh_debouncer.function = self.async_refresh
request_refresh_debouncer.function = self._async_refresh
self._debounced_refresh = request_refresh_debouncer
@@ -277,7 +277,8 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
async def _handle_refresh_interval(self, _now: datetime | None = None) -> None:
"""Handle a refresh interval occurrence."""
self._unsub_refresh = None
await self._async_refresh(log_failures=True, scheduled=True)
async with self._debounced_refresh.async_lock():
await self._async_refresh(log_failures=True, scheduled=True)
async def async_request_refresh(self) -> None:
"""Request a refresh.
@@ -295,6 +296,16 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
async def async_config_entry_first_refresh(self) -> None:
"""Refresh data for the first time when a config entry is setup.
Will automatically raise ConfigEntryNotReady if the refresh
fails. Additionally logging is handled by config entry setup
to ensure that multiple retries do not cause log spam.
"""
async with self._debounced_refresh.async_lock():
await self._async_config_entry_first_refresh()
async def _async_config_entry_first_refresh(self) -> None:
"""Refresh data for the first time when a config entry is setup.
Will automatically raise ConfigEntryNotReady if the refresh
fails. Additionally logging is handled by config entry setup
to ensure that multiple retries do not cause log spam.
@@ -364,7 +375,8 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
async def async_refresh(self) -> None:
"""Refresh data and log errors."""
await self._async_refresh(log_failures=True)
async with self._debounced_refresh.async_lock():
await self._async_refresh(log_failures=True)
async def _async_refresh( # noqa: C901
self,

View File

@@ -19,7 +19,6 @@ from homeassistant.const import (
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.entity_component import async_update_entity
from homeassistant.util.dt import utcnow
from . import init_integration, mock_device, mock_location, mock_reading
@@ -126,8 +125,7 @@ async def test_sensors_attributes_pro(hass: HomeAssistant, canary) -> None:
future = utcnow() + timedelta(seconds=30)
async_fire_time_changed(hass, future)
await async_update_entity(hass, entity_id)
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
state2 = hass.states.get(entity_id)
assert state2
@@ -142,8 +140,7 @@ async def test_sensors_attributes_pro(hass: HomeAssistant, canary) -> None:
future += timedelta(seconds=30)
async_fire_time_changed(hass, future)
await async_update_entity(hass, entity_id)
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
state3 = hass.states.get(entity_id)
assert state3

View File

@@ -922,7 +922,7 @@ async def test_coordinator_updates(
supervisor_client.refresh_updates.assert_not_called()
async_fire_time_changed(hass, dt_util.now() + timedelta(minutes=20))
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
# Scheduled refresh, no update refresh call
supervisor_client.refresh_updates.assert_not_called()
@@ -944,7 +944,7 @@ async def test_coordinator_updates(
async_fire_time_changed(
hass, dt_util.now() + timedelta(seconds=REQUEST_REFRESH_DELAY)
)
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
supervisor_client.refresh_updates.assert_called_once()
supervisor_client.refresh_updates.reset_mock()

View File

@@ -157,8 +157,8 @@ async def test_trophy_title_coordinator_auth_failed(
freezer.tick(timedelta(days=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
await hass.async_block_till_done(wait_background_tasks=True)
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
@@ -194,8 +194,8 @@ async def test_trophy_title_coordinator_update_data_failed(
freezer.tick(timedelta(days=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
await hass.async_block_till_done(wait_background_tasks=True)
runtime_data: PlaystationNetworkRuntimeData = config_entry.runtime_data
assert runtime_data.trophy_titles.last_update_success is False
@@ -254,8 +254,8 @@ async def test_trophy_title_coordinator_play_new_game(
freezer.tick(timedelta(days=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mock_psnawpapi.user.return_value.trophy_titles.mock_calls) == 2

View File

@@ -1,5 +1,6 @@
"""Tests for the update coordinator."""
import asyncio
from datetime import datetime, timedelta
import logging
from unittest.mock import AsyncMock, Mock, patch
@@ -405,6 +406,70 @@ async def test_update_interval_not_present(
assert crd.data is None
async def test_update_locks(
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
crd: update_coordinator.DataUpdateCoordinator[int],
) -> None:
"""Test update interval works."""
start = asyncio.Event()
block = asyncio.Event()
async def _update_method() -> int:
start.set()
await block.wait()
block.clear()
return 0
crd.update_method = _update_method
# Add subscriber
update_callback = Mock()
crd.async_add_listener(update_callback)
assert crd.update_interval
# Trigger timed update, ensure it is started
freezer.tick(crd.update_interval)
async_fire_time_changed(hass)
await start.wait()
start.clear()
# Trigger direct update
task = hass.async_create_background_task(crd.async_refresh(), "", eager_start=True)
freezer.tick(timedelta(seconds=60))
async_fire_time_changed(hass)
# Ensure it has not started
assert not start.is_set()
# Unblock interval update
block.set()
# Check that direct update starts
await start.wait()
start.clear()
# Request update. This should not be blocking
# since the lock is held, it should be queued
await crd.async_request_refresh()
assert not start.is_set()
# Unblock second update
block.set()
# Check that task finishes
await task
# Check that queued update starts
freezer.tick(timedelta(seconds=60))
async_fire_time_changed(hass)
await start.wait()
start.clear()
# Unblock queued update
block.set()
async def test_refresh_recover(
crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture
) -> None: