mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 21:06:19 +00:00
Protect internal coordinator state (#153685)
This commit is contained in:
@@ -31,7 +31,7 @@ async def async_setup_entry(
|
||||
for location_id, location in coordinator.data["locations"].items()
|
||||
]
|
||||
|
||||
async_add_entities(alarms, True)
|
||||
async_add_entities(alarms)
|
||||
|
||||
|
||||
class CanaryAlarm(
|
||||
|
||||
@@ -68,8 +68,7 @@ async def async_setup_entry(
|
||||
for location_id, location in coordinator.data["locations"].items()
|
||||
for device in location.devices
|
||||
if device.is_online
|
||||
),
|
||||
True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ async def async_setup_entry(
|
||||
if device_type.get("name") in sensor_type[4]
|
||||
)
|
||||
|
||||
async_add_entities(sensors, True)
|
||||
async_add_entities(sensors)
|
||||
|
||||
|
||||
class CanarySensor(CoordinatorEntity[CanaryDataUpdateCoordinator], SensorEntity):
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import HassJob, HomeAssistant, callback
|
||||
|
||||
@@ -36,6 +38,7 @@ class Debouncer[_R_co]:
|
||||
self._timer_task: asyncio.TimerHandle | None = None
|
||||
self._execute_at_end_of_timer: bool = False
|
||||
self._execute_lock = asyncio.Lock()
|
||||
self._execute_lock_owner: asyncio.Task[Any] | None = None
|
||||
self._background = background
|
||||
self._job: HassJob[[], _R_co] | None = (
|
||||
None
|
||||
@@ -46,6 +49,22 @@ class Debouncer[_R_co]:
|
||||
)
|
||||
self._shutdown_requested = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_lock(self) -> AsyncGenerator[None]:
|
||||
"""Return an async context manager to lock the debouncer."""
|
||||
if self._execute_lock_owner is asyncio.current_task():
|
||||
raise RuntimeError("Debouncer lock is not re-entrant")
|
||||
|
||||
if self._execute_lock.locked():
|
||||
self.logger.debug("Debouncer lock is already acquired, waiting")
|
||||
|
||||
async with self._execute_lock:
|
||||
self._execute_lock_owner = asyncio.current_task()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._execute_lock_owner = None
|
||||
|
||||
@property
|
||||
def function(self) -> Callable[[], _R_co] | None:
|
||||
"""Return the function being wrapped by the Debouncer."""
|
||||
@@ -98,7 +117,7 @@ class Debouncer[_R_co]:
|
||||
if not self._async_schedule_or_call_now():
|
||||
return
|
||||
|
||||
async with self._execute_lock:
|
||||
async with self.async_lock():
|
||||
# Abort if timer got set while we're waiting for the lock.
|
||||
if self._timer_task:
|
||||
return
|
||||
@@ -122,7 +141,7 @@ class Debouncer[_R_co]:
|
||||
if self._execute_lock.locked():
|
||||
return
|
||||
|
||||
async with self._execute_lock:
|
||||
async with self.async_lock():
|
||||
# Abort if timer got set while we're waiting for the lock.
|
||||
if self._timer_task:
|
||||
return
|
||||
|
||||
@@ -128,10 +128,10 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
|
||||
logger,
|
||||
cooldown=REQUEST_REFRESH_DEFAULT_COOLDOWN,
|
||||
immediate=REQUEST_REFRESH_DEFAULT_IMMEDIATE,
|
||||
function=self.async_refresh,
|
||||
function=self._async_refresh,
|
||||
)
|
||||
else:
|
||||
request_refresh_debouncer.function = self.async_refresh
|
||||
request_refresh_debouncer.function = self._async_refresh
|
||||
|
||||
self._debounced_refresh = request_refresh_debouncer
|
||||
|
||||
@@ -277,7 +277,8 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
|
||||
async def _handle_refresh_interval(self, _now: datetime | None = None) -> None:
|
||||
"""Handle a refresh interval occurrence."""
|
||||
self._unsub_refresh = None
|
||||
await self._async_refresh(log_failures=True, scheduled=True)
|
||||
async with self._debounced_refresh.async_lock():
|
||||
await self._async_refresh(log_failures=True, scheduled=True)
|
||||
|
||||
async def async_request_refresh(self) -> None:
|
||||
"""Request a refresh.
|
||||
@@ -295,6 +296,16 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
|
||||
async def async_config_entry_first_refresh(self) -> None:
|
||||
"""Refresh data for the first time when a config entry is setup.
|
||||
|
||||
Will automatically raise ConfigEntryNotReady if the refresh
|
||||
fails. Additionally logging is handled by config entry setup
|
||||
to ensure that multiple retries do not cause log spam.
|
||||
"""
|
||||
async with self._debounced_refresh.async_lock():
|
||||
await self._async_config_entry_first_refresh()
|
||||
|
||||
async def _async_config_entry_first_refresh(self) -> None:
|
||||
"""Refresh data for the first time when a config entry is setup.
|
||||
|
||||
Will automatically raise ConfigEntryNotReady if the refresh
|
||||
fails. Additionally logging is handled by config entry setup
|
||||
to ensure that multiple retries do not cause log spam.
|
||||
@@ -364,7 +375,8 @@ class DataUpdateCoordinator(BaseDataUpdateCoordinatorProtocol, Generic[_DataT]):
|
||||
|
||||
async def async_refresh(self) -> None:
|
||||
"""Refresh data and log errors."""
|
||||
await self._async_refresh(log_failures=True)
|
||||
async with self._debounced_refresh.async_lock():
|
||||
await self._async_refresh(log_failures=True)
|
||||
|
||||
async def _async_refresh( # noqa: C901
|
||||
self,
|
||||
|
||||
@@ -19,7 +19,6 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers.entity_component import async_update_entity
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from . import init_integration, mock_device, mock_location, mock_reading
|
||||
@@ -126,8 +125,7 @@ async def test_sensors_attributes_pro(hass: HomeAssistant, canary) -> None:
|
||||
|
||||
future = utcnow() + timedelta(seconds=30)
|
||||
async_fire_time_changed(hass, future)
|
||||
await async_update_entity(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
state2 = hass.states.get(entity_id)
|
||||
assert state2
|
||||
@@ -142,8 +140,7 @@ async def test_sensors_attributes_pro(hass: HomeAssistant, canary) -> None:
|
||||
|
||||
future += timedelta(seconds=30)
|
||||
async_fire_time_changed(hass, future)
|
||||
await async_update_entity(hass, entity_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
state3 = hass.states.get(entity_id)
|
||||
assert state3
|
||||
|
||||
@@ -922,7 +922,7 @@ async def test_coordinator_updates(
|
||||
supervisor_client.refresh_updates.assert_not_called()
|
||||
|
||||
async_fire_time_changed(hass, dt_util.now() + timedelta(minutes=20))
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
# Scheduled refresh, no update refresh call
|
||||
supervisor_client.refresh_updates.assert_not_called()
|
||||
@@ -944,7 +944,7 @@ async def test_coordinator_updates(
|
||||
async_fire_time_changed(
|
||||
hass, dt_util.now() + timedelta(seconds=REQUEST_REFRESH_DELAY)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
supervisor_client.refresh_updates.assert_called_once()
|
||||
|
||||
supervisor_client.refresh_updates.reset_mock()
|
||||
|
||||
@@ -157,8 +157,8 @@ async def test_trophy_title_coordinator_auth_failed(
|
||||
|
||||
freezer.tick(timedelta(days=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
@@ -194,8 +194,8 @@ async def test_trophy_title_coordinator_update_data_failed(
|
||||
|
||||
freezer.tick(timedelta(days=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
runtime_data: PlaystationNetworkRuntimeData = config_entry.runtime_data
|
||||
assert runtime_data.trophy_titles.last_update_success is False
|
||||
@@ -254,8 +254,8 @@ async def test_trophy_title_coordinator_play_new_game(
|
||||
|
||||
freezer.tick(timedelta(days=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
assert len(mock_psnawpapi.user.return_value.trophy_titles.mock_calls) == 2
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for the update coordinator."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
@@ -405,6 +406,70 @@ async def test_update_interval_not_present(
|
||||
assert crd.data is None
|
||||
|
||||
|
||||
async def test_update_locks(
|
||||
hass: HomeAssistant,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
crd: update_coordinator.DataUpdateCoordinator[int],
|
||||
) -> None:
|
||||
"""Test update interval works."""
|
||||
start = asyncio.Event()
|
||||
block = asyncio.Event()
|
||||
|
||||
async def _update_method() -> int:
|
||||
start.set()
|
||||
await block.wait()
|
||||
block.clear()
|
||||
return 0
|
||||
|
||||
crd.update_method = _update_method
|
||||
|
||||
# Add subscriber
|
||||
update_callback = Mock()
|
||||
crd.async_add_listener(update_callback)
|
||||
|
||||
assert crd.update_interval
|
||||
|
||||
# Trigger timed update, ensure it is started
|
||||
freezer.tick(crd.update_interval)
|
||||
async_fire_time_changed(hass)
|
||||
await start.wait()
|
||||
start.clear()
|
||||
|
||||
# Trigger direct update
|
||||
task = hass.async_create_background_task(crd.async_refresh(), "", eager_start=True)
|
||||
freezer.tick(timedelta(seconds=60))
|
||||
async_fire_time_changed(hass)
|
||||
|
||||
# Ensure it has not started
|
||||
assert not start.is_set()
|
||||
|
||||
# Unblock interval update
|
||||
block.set()
|
||||
|
||||
# Check that direct update starts
|
||||
await start.wait()
|
||||
start.clear()
|
||||
|
||||
# Request update. This should not be blocking
|
||||
# since the lock is held, it should be queued
|
||||
await crd.async_request_refresh()
|
||||
assert not start.is_set()
|
||||
|
||||
# Unblock second update
|
||||
block.set()
|
||||
# Check that task finishes
|
||||
await task
|
||||
|
||||
# Check that queued update starts
|
||||
freezer.tick(timedelta(seconds=60))
|
||||
async_fire_time_changed(hass)
|
||||
await start.wait()
|
||||
start.clear()
|
||||
|
||||
# Unblock queued update
|
||||
block.set()
|
||||
|
||||
|
||||
async def test_refresh_recover(
|
||||
crd: update_coordinator.DataUpdateCoordinator[int], caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user