1
0
mirror of https://github.com/home-assistant/core.git synced 2026-04-02 00:20:30 +01:00

Make restore state resilient to extra_restore_state_data errors (#165086)

This commit is contained in:
Artur Pragacz
2026-03-08 10:39:53 +01:00
committed by GitHub
parent 5031323dea
commit 3154c3c962
2 changed files with 130 additions and 14 deletions

View File

@@ -181,15 +181,24 @@ class RestoreStateData:
}
# Start with the currently registered states
stored_states = [
StoredState(
current_states_by_entity_id[entity_id],
entity.extra_restore_state_data,
now,
stored_states: list[StoredState] = []
for entity_id, entity in self.entities.items():
if entity_id not in current_states_by_entity_id:
continue
try:
extra_data = entity.extra_restore_state_data
except Exception:
_LOGGER.exception(
"Error getting extra restore state data for %s", entity_id
)
continue
stored_states.append(
StoredState(
current_states_by_entity_id[entity_id],
extra_data,
now,
)
)
for entity_id, entity in self.entities.items()
if entity_id in current_states_by_entity_id
]
expiration_time = now - STATE_EXPIRATION
for entity_id, stored_state in self.last_states.items():
@@ -219,6 +228,8 @@ class RestoreStateData:
)
except HomeAssistantError as exc:
_LOGGER.error("Error saving current states", exc_info=exc)
except Exception:
_LOGGER.exception("Unexpected error saving current states")
@callback
def async_setup_dump(self, *args: Any) -> None:
@@ -258,13 +269,15 @@ class RestoreStateData:
@callback
def async_restore_entity_removed(
self, entity_id: str, extra_data: ExtraStoredData | None
self,
entity_id: str,
state: State | None,
extra_data: ExtraStoredData | None,
) -> None:
"""Unregister this entity from saving state."""
# When an entity is being removed from hass, store its last state. This
# allows us to support state restoration if the entity is removed, then
# re-added while hass is still running.
state = self.hass.states.get(entity_id)
# To fully mimic all the attribute data types when loaded from storage,
# we're going to serialize it to JSON and then re-load it.
if state is not None:
@@ -287,8 +300,18 @@ class RestoreEntity(Entity):
async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
try:
extra_data = self.extra_restore_state_data
except Exception:
_LOGGER.exception(
"Error getting extra restore state data for %s", self.entity_id
)
state = None
extra_data = None
else:
state = self.hass.states.get(self.entity_id)
async_get(self.hass).async_restore_entity_removed(
self.entity_id, self.extra_restore_state_data
self.entity_id, state, extra_data
)
await super().async_internal_will_remove_from_hass()

View File

@@ -6,6 +6,8 @@ import logging
from typing import Any
from unittest.mock import Mock, patch
import pytest
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError
@@ -16,6 +18,7 @@ from homeassistant.helpers.reload import async_get_platform_without_config_entry
from homeassistant.helpers.restore_state import (
DATA_RESTORE_STATE,
STORAGE_KEY,
ExtraStoredData,
RestoreEntity,
RestoreStateData,
StoredState,
@@ -342,8 +345,12 @@ async def test_dump_data(hass: HomeAssistant) -> None:
assert state1["state"]["state"] == "off"
async def test_dump_error(hass: HomeAssistant) -> None:
"""Test that we cache data."""
@pytest.mark.parametrize(
"exception",
[HomeAssistantError, RuntimeError],
)
async def test_dump_error(hass: HomeAssistant, exception: type[Exception]) -> None:
"""Test that errors during save are caught."""
states = [
State("input_boolean.b0", "on"),
State("input_boolean.b1", "on"),
@@ -368,7 +375,7 @@ async def test_dump_error(hass: HomeAssistant) -> None:
with patch(
"homeassistant.helpers.restore_state.Store.async_save",
side_effect=HomeAssistantError,
side_effect=exception,
) as mock_write_data:
await data.async_dump_states()
@@ -534,3 +541,89 @@ async def test_restore_entity_end_to_end(
assert len(storage_data) == 1
assert storage_data[0]["state"]["entity_id"] == entity_id
assert storage_data[0]["state"]["state"] == "stored"
async def test_dump_states_with_failing_extra_data(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that a failing extra_restore_state_data skips only that entity."""
class BadRestoreEntity(RestoreEntity):
"""Entity that raises on extra_restore_state_data."""
@property
def extra_restore_state_data(self) -> ExtraStoredData | None:
raise RuntimeError("Unexpected error")
states = [
State("input_boolean.good", "on"),
State("input_boolean.bad", "on"),
]
platform = MockEntityPlatform(hass, domain="input_boolean")
good_entity = RestoreEntity()
good_entity.hass = hass
good_entity.entity_id = "input_boolean.good"
await platform.async_add_entities([good_entity])
bad_entity = BadRestoreEntity()
bad_entity.hass = hass
bad_entity.entity_id = "input_boolean.bad"
await platform.async_add_entities([bad_entity])
for state in states:
hass.states.async_set(state.entity_id, state.state, state.attributes)
data = async_get(hass)
with patch(
"homeassistant.helpers.restore_state.Store.async_save"
) as mock_write_data:
await data.async_dump_states()
assert mock_write_data.called
written_states = mock_write_data.mock_calls[0][1][0]
# Only the good entity should be saved
assert len(written_states) == 1
state0 = json_round_trip(written_states[0])
assert state0["state"]["entity_id"] == "input_boolean.good"
assert state0["state"]["state"] == "on"
assert "Error getting extra restore state data for input_boolean.bad" in caplog.text
async def test_entity_removal_with_failing_extra_data(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that entity removal succeeds even if extra_restore_state_data raises."""
class BadRestoreEntity(RestoreEntity):
"""Entity that raises on extra_restore_state_data."""
@property
def extra_restore_state_data(self) -> ExtraStoredData | None:
raise RuntimeError("Unexpected error")
platform = MockEntityPlatform(hass, domain="input_boolean")
entity = BadRestoreEntity()
entity.hass = hass
entity.entity_id = "input_boolean.bad"
await platform.async_add_entities([entity])
hass.states.async_set("input_boolean.bad", "on")
data = async_get(hass)
assert "input_boolean.bad" in data.entities
await entity.async_remove()
# Entity should be unregistered
assert "input_boolean.bad" not in data.entities
# No last state should be saved since extra data failed
assert "input_boolean.bad" not in data.last_states
assert "Error getting extra restore state data for input_boolean.bad" in caplog.text