diff --git a/homeassistant/components/frontend/storage.py b/homeassistant/components/frontend/storage.py index 11d155dbcb4..aa1ce27e3cd 100644 --- a/homeassistant/components/frontend/storage.py +++ b/homeassistant/components/frontend/storage.py @@ -11,11 +11,14 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.components.websocket_api import ActiveConnection from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import singleton from homeassistant.helpers.storage import Store from homeassistant.util.hass_dict import HassKey DATA_STORAGE: HassKey[dict[str, UserStore]] = HassKey("frontend_storage") +DATA_SYSTEM_STORAGE: HassKey[SystemStore] = HassKey("frontend_system_storage") STORAGE_VERSION_USER_DATA = 1 +STORAGE_VERSION_SYSTEM_DATA = 1 async def async_setup_frontend_storage(hass: HomeAssistant) -> None: @@ -23,6 +26,9 @@ async def async_setup_frontend_storage(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_set_user_data) websocket_api.async_register_command(hass, websocket_get_user_data) websocket_api.async_register_command(hass, websocket_subscribe_user_data) + websocket_api.async_register_command(hass, websocket_set_system_data) + websocket_api.async_register_command(hass, websocket_get_system_data) + websocket_api.async_register_command(hass, websocket_subscribe_system_data) async def async_user_store(hass: HomeAssistant, user_id: str) -> UserStore: @@ -83,6 +89,52 @@ class _UserStore(Store[dict[str, Any]]): ) +@singleton.singleton(DATA_SYSTEM_STORAGE, async_=True) +async def async_system_store(hass: HomeAssistant) -> SystemStore: + """Access the system store.""" + store = SystemStore(hass) + await store.async_load() + return store + + +class SystemStore: + """System store for frontend data.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the system store.""" + self._store: Store[dict[str, Any]] = Store( + hass, + STORAGE_VERSION_SYSTEM_DATA, + "frontend.system_data", + ) + self.data: dict[str, Any] = {} + self.subscriptions: dict[str, list[Callable[[], None]]] = {} + + async def async_load(self) -> None: + """Load the data from the store.""" + self.data = await self._store.async_load() or {} + + async def async_set_item(self, key: str, value: Any) -> None: + """Set an item and save the store.""" + self.data[key] = value + self._store.async_delay_save(lambda: self.data, 1.0) + for cb in self.subscriptions.get(key, []): + cb() + + @callback + def async_subscribe( + self, key: str, on_update_callback: Callable[[], None] + ) -> Callable[[], None]: + """Subscribe to store updates.""" + self.subscriptions.setdefault(key, []).append(on_update_callback) + + def unsubscribe() -> None: + """Unsubscribe from the store.""" + self.subscriptions[key].remove(on_update_callback) + + return unsubscribe + + def with_user_store( orig_func: Callable[ [HomeAssistant, ActiveConnection, dict[str, Any], UserStore], @@ -107,6 +159,28 @@ def with_user_store( return with_user_store_func +def with_system_store( + orig_func: Callable[ + [HomeAssistant, ActiveConnection, dict[str, Any], SystemStore], + Coroutine[Any, Any, None], + ], +) -> Callable[ + [HomeAssistant, ActiveConnection, dict[str, Any]], Coroutine[Any, Any, None] +]: + """Decorate function to provide system store.""" + + @wraps(orig_func) + async def with_system_store_func( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: + """Provide system store to function.""" + store = await async_system_store(hass) + + await orig_func(hass, connection, msg, store) + + return with_system_store_func + + @websocket_api.websocket_command( { vol.Required("type"): "frontend/set_user_data", @@ -169,3 +243,65 @@ async def websocket_subscribe_user_data( connection.subscriptions[msg["id"]] = store.async_subscribe(key, on_data_update) on_data_update() connection.send_result(msg["id"]) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "frontend/set_system_data", + vol.Required("key"): str, + vol.Required("value"): vol.Any(bool, str, int, float, dict, list, None), + } +) +@websocket_api.require_admin +@websocket_api.async_response +@with_system_store +async def websocket_set_system_data( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict[str, Any], + store: SystemStore, +) -> None: + """Handle set system data command.""" + await store.async_set_item(msg["key"], msg["value"]) + connection.send_result(msg["id"]) + + +@websocket_api.websocket_command( + {vol.Required("type"): "frontend/get_system_data", vol.Required("key"): str} +) +@websocket_api.async_response +@with_system_store +async def websocket_get_system_data( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict[str, Any], + store: SystemStore, +) -> None: + """Handle get system data command.""" + connection.send_result(msg["id"], {"value": store.data.get(msg["key"])}) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "frontend/subscribe_system_data", + vol.Required("key"): str, + } +) +@websocket_api.async_response +@with_system_store +async def websocket_subscribe_system_data( + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict[str, Any], + store: SystemStore, +) -> None: + """Handle subscribe to system data command.""" + key: str = msg["key"] + + def on_data_update() -> None: + """Handle system data update.""" + connection.send_event(msg["id"], {"value": store.data.get(key)}) + + connection.subscriptions[msg["id"]] = store.async_subscribe(key, on_data_update) + on_data_update() + connection.send_result(msg["id"]) diff --git a/tests/components/frontend/test_storage.py b/tests/components/frontend/test_storage.py index f4a61b743c5..097b00e2e9c 100644 --- a/tests/components/frontend/test_storage.py +++ b/tests/components/frontend/test_storage.py @@ -301,3 +301,274 @@ async def test_set_user_data( res = await client.receive_json() assert res["success"], res assert res["result"]["value"] == "test-value" + + +async def test_get_system_data_empty( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test get_system_data command.""" + client = await hass_ws_client(hass) + + await client.send_json( + {"id": 5, "type": "frontend/get_system_data", "key": "non-existing-key"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"] is None + + +async def test_get_system_data( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + hass_storage: dict[str, Any], +) -> None: + """Test get_system_data command.""" + storage_key = f"{DOMAIN}.system_data" + hass_storage[storage_key] = { + "key": storage_key, + "version": 1, + "data": {"test-key": "test-value", "test-complex": [{"foo": "bar"}]}, + } + + client = await hass_ws_client(hass) + + # Get a simple string key + + await client.send_json( + {"id": 6, "type": "frontend/get_system_data", "key": "test-key"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"] == "test-value" + + # Get a more complex key + + await client.send_json( + {"id": 7, "type": "frontend/get_system_data", "key": "test-complex"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"][0]["foo"] == "bar" + + +@pytest.mark.parametrize( + ("subscriptions", "events"), + [ + ([], []), + ([(1, {"key": "test-key"}, None)], [(1, "test-value")]), + ([(1, {"key": "other-key"}, None)], []), + ], +) +async def test_set_system_data_empty( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + subscriptions: list[tuple[int, dict[str, str], Any]], + events: list[tuple[int, Any]], +) -> None: + """Test set_system_data command. + + Also test subscribing. + """ + client = await hass_ws_client(hass) + + for msg_id, key, event_data in subscriptions: + await client.send_json( + { + "id": msg_id, + "type": "frontend/subscribe_system_data", + } + | key + ) + + event = await client.receive_json() + assert event == { + "id": msg_id, + "type": "event", + "event": {"value": event_data}, + } + + res = await client.receive_json() + assert res["success"], res + + # test creating + + await client.send_json( + {"id": 6, "type": "frontend/get_system_data", "key": "test-key"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"] is None + + await client.send_json( + { + "id": 7, + "type": "frontend/set_system_data", + "key": "test-key", + "value": "test-value", + } + ) + + for msg_id, event_data in events: + event = await client.receive_json() + assert event == {"id": msg_id, "type": "event", "event": {"value": event_data}} + + res = await client.receive_json() + assert res["success"], res + + await client.send_json( + {"id": 8, "type": "frontend/get_system_data", "key": "test-key"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"] == "test-value" + + +@pytest.mark.parametrize( + ("subscriptions", "events"), + [ + ( + [], + [[], []], + ), + ( + [(1, {"key": "test-key"}, "test-value")], + [[], []], + ), + ( + [(1, {"key": "test-non-existent-key"}, None)], + [[(1, "test-value-new")], []], + ), + ( + [(1, {"key": "test-complex"}, "string")], + [[], [(1, [{"foo": "bar"}])]], + ), + ( + [(1, {"key": "other-key"}, None)], + [[], []], + ), + ], +) +async def test_set_system_data( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + hass_storage: dict[str, Any], + subscriptions: list[tuple[int, dict[str, str], Any]], + events: list[list[tuple[int, Any]]], +) -> None: + """Test set_system_data command with initial data.""" + storage_key = f"{DOMAIN}.system_data" + hass_storage[storage_key] = { + "version": 1, + "data": {"test-key": "test-value", "test-complex": "string"}, + } + + client = await hass_ws_client(hass) + + for msg_id, key, event_data in subscriptions: + await client.send_json( + { + "id": msg_id, + "type": "frontend/subscribe_system_data", + } + | key + ) + + event = await client.receive_json() + assert event == { + "id": msg_id, + "type": "event", + "event": {"value": event_data}, + } + + res = await client.receive_json() + assert res["success"], res + + # test creating + + await client.send_json( + { + "id": 5, + "type": "frontend/set_system_data", + "key": "test-non-existent-key", + "value": "test-value-new", + } + ) + + for msg_id, event_data in events[0]: + event = await client.receive_json() + assert event == {"id": msg_id, "type": "event", "event": {"value": event_data}} + + res = await client.receive_json() + assert res["success"], res + + await client.send_json( + {"id": 6, "type": "frontend/get_system_data", "key": "test-non-existent-key"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"] == "test-value-new" + + # test updating with complex data + + await client.send_json( + { + "id": 7, + "type": "frontend/set_system_data", + "key": "test-complex", + "value": [{"foo": "bar"}], + } + ) + + for msg_id, event_data in events[1]: + event = await client.receive_json() + assert event == {"id": msg_id, "type": "event", "event": {"value": event_data}} + + res = await client.receive_json() + assert res["success"], res + + await client.send_json( + {"id": 8, "type": "frontend/get_system_data", "key": "test-complex"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"][0]["foo"] == "bar" + + # ensure other existing key was not modified + + await client.send_json( + {"id": 9, "type": "frontend/get_system_data", "key": "test-key"} + ) + + res = await client.receive_json() + assert res["success"], res + assert res["result"]["value"] == "test-value" + + +async def test_set_system_data_requires_admin( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + hass_read_only_access_token: str, +) -> None: + """Test set_system_data requires admin permissions.""" + client = await hass_ws_client(hass, hass_read_only_access_token) + + await client.send_json( + { + "id": 5, + "type": "frontend/set_system_data", + "key": "test-key", + "value": "test-value", + } + ) + + res = await client.receive_json() + assert not res["success"], res + assert res["error"]["code"] == "unauthorized" + assert res["error"]["message"] == "Unauthorized"