1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-15 07:36:16 +00:00

Respect callback decorator in store helper async_delay_save (#157158)

This commit is contained in:
Erik Montnemery
2025-11-25 16:08:09 +01:00
committed by GitHub
parent 0c366506c5
commit 4c04dc00dd
5 changed files with 136 additions and 28 deletions

View File

@@ -160,15 +160,20 @@ def _orjson_bytes_default_encoder(data: Any) -> bytes:
)
def save_json(
filename: str,
def prepare_save_json(
data: list | dict,
private: bool = False,
*,
encoder: type[json.JSONEncoder] | None = None,
atomic_writes: bool = False,
) -> None:
"""Save JSON data to a file."""
) -> tuple[str, str | bytes]:
"""Prepare JSON data for saving to a file.
Returns a tuple of (mode, json_data) where mode is either 'w' or 'wb'
and json_data is either a str or bytes depending on the mode.
Args:
data: Data to serialize.
encoder: Optional custom JSON encoder.
"""
dump: Callable[[Any], Any]
try:
# For backwards compatibility, if they pass in the
@@ -188,10 +193,24 @@ def save_json(
formatted_data = format_unserializable_data(
find_paths_unserializable_data(data, dump=dump)
)
msg = f"Failed to serialize to JSON: {filename}. Bad data at {formatted_data}"
_LOGGER.error(msg)
raise SerializationError(msg) from error
raise SerializationError(f"Bad data at {formatted_data}") from error
return (mode, json_data)
def save_json(
filename: str,
data: list | dict,
private: bool = False,
*,
encoder: type[json.JSONEncoder] | None = None,
atomic_writes: bool = False,
) -> None:
"""Save JSON data to a file."""
try:
mode, json_data = prepare_save_json(data, encoder=encoder)
except SerializationError as err:
_LOGGER.error("Failed to serialize to JSON: %s. %s", filename, err)
raise
method = write_utf8_file_atomic if atomic_writes else write_utf8_file
method(filename, json_data, private, mode=mode)

View File

@@ -27,11 +27,12 @@ from homeassistant.core import (
Event,
HomeAssistant,
callback,
is_callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util, json as json_util
from homeassistant.util.file import WriteError
from homeassistant.util.file import WriteError, write_utf8_file, write_utf8_file_atomic
from homeassistant.util.hass_dict import HassKey
from . import json as json_helper
@@ -441,7 +442,12 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
data_func: Callable[[], _T],
delay: float = 0,
) -> None:
"""Save data with an optional delay."""
"""Save data with an optional delay.
data_func: A function that returns the data to save. If the function
is decorated with @callback, it will be called in the event loop. If
it is a regular function, it will be called from an executor.
"""
self._data = {
"version": self.version,
"minor_version": self.minor_version,
@@ -537,28 +543,37 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
return
try:
await self._async_write_data(self.path, data)
await self._async_write_data(data)
except (json_util.SerializationError, WriteError) as err:
_LOGGER.error("Error writing config for %s: %s", self.key, err)
async def _async_write_data(self, path: str, data: dict) -> None:
await self.hass.async_add_executor_job(self._write_data, self.path, data)
async def _async_write_data(self, data: dict) -> None:
if "data_func" in data and is_callback(data["data_func"]):
data["data"] = data.pop("data_func")()
mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
await self.hass.async_add_executor_job(
self._write_prepared_data, mode, json_data
)
return
await self.hass.async_add_executor_job(self._write_data, data)
def _write_data(self, path: str, data: dict) -> None:
def _write_data(self, data: dict) -> None:
"""Write the data."""
os.makedirs(os.path.dirname(path), exist_ok=True)
if "data_func" in data:
data["data"] = data.pop("data_func")()
mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
self._write_prepared_data(mode, json_data)
def _write_prepared_data(self, mode: str, json_data: str | bytes) -> None:
"""Write the data."""
path = self.path
os.makedirs(os.path.dirname(path), exist_ok=True)
_LOGGER.debug("Writing data for %s to %s", self.key, path)
json_helper.save_json(
path,
data,
self._private,
encoder=self._encoder,
atomic_writes=self._atomic_writes,
write_method = (
write_utf8_file_atomic if self._atomic_writes else write_utf8_file
)
write_method(path, json_data, self._private, mode=mode)
async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
"""Migrate to the new version."""

View File

@@ -1531,7 +1531,7 @@ def mock_storage(data: dict[str, Any] | None = None) -> Generator[dict[str, Any]
return loaded
async def mock_write_data(
store: storage.Store, path: str, data_to_write: dict[str, Any]
store: storage.Store, data_to_write: dict[str, Any]
) -> None:
"""Mock version of write data."""
# To ensure that the data can be serialized

View File

@@ -237,9 +237,7 @@ def test_save_bad_data() -> None:
with pytest.raises(SerializationError) as excinfo:
save_json("test4", {"hello": CannotSerializeMe()})
assert "Failed to serialize to JSON: test4. Bad data at $.hello=" in str(
excinfo.value
)
assert "Bad data at $.hello=" in str(excinfo.value)
def test_custom_encoder(tmp_path: Path) -> None:

View File

@@ -4,6 +4,8 @@ import asyncio
from datetime import timedelta
import json
import os
from pathlib import Path
import threading
from typing import Any, NamedTuple
from unittest.mock import Mock, patch
@@ -17,7 +19,12 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, CoreState, HomeAssistant
from homeassistant.core import (
DOMAIN as HOMEASSISTANT_DOMAIN,
CoreState,
HomeAssistant,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import issue_registry as ir, storage
from homeassistant.helpers.json import json_bytes
@@ -35,6 +42,7 @@ MOCK_VERSION_2 = 2
MOCK_MINOR_VERSION_1 = 1
MOCK_MINOR_VERSION_2 = 2
MOCK_KEY = "storage-test"
MOCK_KEY2 = "storage-test-2"
MOCK_DATA = {"hello": "world"}
MOCK_DATA2 = {"goodbye": "cruel world"}
@@ -140,6 +148,74 @@ async def test_saving_with_delay(
}
async def test_saving_with_delay_threading(tmp_path: Path) -> None:
"""Test thread handling when saving with a delay."""
calls = []
async def assert_storage_data(store_key: str, expected_data: str) -> None:
"""Assert storage data."""
def read_storage_data(store_key: str) -> str:
"""Read storage data."""
return Path(tmp_path / f".storage/{store_key}").read_text(encoding="utf-8")
store_data = await asyncio.to_thread(read_storage_data, store_key)
assert store_data == expected_data
async with async_test_home_assistant(config_dir=tmp_path) as hass:
def data_producer_thread_safe() -> Any:
"""Produce data to store."""
assert threading.get_ident() != hass.loop_thread_id
calls.append("thread_safe")
return MOCK_DATA
@callback
def data_producer_callback() -> Any:
"""Produce data to store."""
assert threading.get_ident() == hass.loop_thread_id
calls.append("callback")
return MOCK_DATA2
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
store.async_delay_save(data_producer_thread_safe, 1)
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done()
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY2)
store.async_delay_save(data_producer_callback, 1)
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done()
assert calls == ["thread_safe", "callback"]
expected_data = (
"{\n"
' "version": 1,\n'
' "minor_version": 1,\n'
' "key": "storage-test",\n'
' "data": {\n'
' "hello": "world"\n'
" }\n"
"}"
)
await assert_storage_data(MOCK_KEY, expected_data)
expected_data = (
"{\n"
' "version": 1,\n'
' "minor_version": 1,\n'
' "key": "storage-test-2",\n'
' "data": {\n'
' "goodbye": "cruel world"\n'
" }\n"
"}"
)
await assert_storage_data(MOCK_KEY2, expected_data)
await hass.async_stop(force=True)
async def test_saving_with_delay_churn_reduction(
hass: HomeAssistant,
store: storage.Store,