diff --git a/homeassistant/components/hassio/__init__.py b/homeassistant/components/hassio/__init__.py index 0c70e83cb79..ae72546a10d 100644 --- a/homeassistant/components/hassio/__init__.py +++ b/homeassistant/components/hassio/__init__.py @@ -620,7 +620,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - # Pop add-on data + # Unload coordinator + coordinator: HassioDataUpdateCoordinator = hass.data[ADDONS_COORDINATOR] + coordinator.unload() + + # Pop coordinator hass.data.pop(ADDONS_COORDINATOR, None) return unload_ok diff --git a/homeassistant/components/hassio/coordinator.py b/homeassistant/components/hassio/coordinator.py index a49c7b88580..5277f1fa496 100644 --- a/homeassistant/components/hassio/coordinator.py +++ b/homeassistant/components/hassio/coordinator.py @@ -563,3 +563,8 @@ class HassioDataUpdateCoordinator(DataUpdateCoordinator): self.async_set_updated_data(data) except SupervisorError as err: _LOGGER.warning("Could not refresh info for %s: %s", addon_slug, err) + + @callback + def unload(self) -> None: + """Clean up when config entry unloaded.""" + self.jobs.unload() diff --git a/homeassistant/components/hassio/jobs.py b/homeassistant/components/hassio/jobs.py index a3cfb695296..4bd969ecc6e 100644 --- a/homeassistant/components/hassio/jobs.py +++ b/homeassistant/components/hassio/jobs.py @@ -3,6 +3,7 @@ from collections.abc import Callable from dataclasses import dataclass, replace from functools import partial +import logging from typing import Any from uuid import UUID @@ -29,6 +30,8 @@ from .const import ( ) from .handler import get_supervisor_client +_LOGGER = logging.getLogger(__name__) + @dataclass(slots=True, frozen=True) class JobSubscription: @@ -45,7 +48,7 @@ class JobSubscription: event_callback: Callable[[Job], Any] uuid: str | None = None name: str | None = None - reference: str | None | type[Any] = Any + reference: str | None = None def __post_init__(self) -> None: """Validate at least one filter option is present.""" @@ -58,7 +61,7 @@ class JobSubscription: """Return true if job matches subscription filters.""" if self.uuid: return job.uuid == self.uuid - return job.name == self.name and self.reference in (Any, job.reference) + return job.name == self.name and self.reference in (None, job.reference) class SupervisorJobs: @@ -70,6 +73,7 @@ class SupervisorJobs: self._supervisor_client = get_supervisor_client(hass) self._jobs: dict[UUID, Job] = {} self._subscriptions: set[JobSubscription] = set() + self._dispatcher_disconnect: Callable[[], None] | None = None @property def current_jobs(self) -> list[Job]: @@ -79,20 +83,24 @@ class SupervisorJobs: def subscribe(self, subscription: JobSubscription) -> CALLBACK_TYPE: """Subscribe to updates for job. Return callback is used to unsubscribe. - If any jobs match the subscription at the time this is called, creates - tasks to run their callback on it. + If any jobs match the subscription at the time this is called, runs the + callback on them. """ self._subscriptions.add(subscription) - # As these are callbacks they are safe to run in the event loop - # We wrap these in an asyncio task so subscribing does not wait on the logic - if matches := [job for job in self._jobs.values() if subscription.matches(job)]: - - async def event_callback_async(job: Job) -> Any: - return subscription.event_callback(job) - - for match in matches: - self._hass.async_create_task(event_callback_async(match)) + # Run the callback on each existing match + # We catch all errors to prevent an error in one from stopping the others + for match in [job for job in self._jobs.values() if subscription.matches(job)]: + try: + return subscription.event_callback(match) + except Exception as err: # noqa: BLE001 + _LOGGER.error( + "Error encountered processing Supervisor Job (%s %s %s) - %s", + match.name, + match.reference, + match.uuid, + err, + ) return partial(self._subscriptions.discard, subscription) @@ -131,7 +139,7 @@ class SupervisorJobs: # If this is the first update register to receive Supervisor events if first_update: - async_dispatcher_connect( + self._dispatcher_disconnect = async_dispatcher_connect( self._hass, EVENT_SUPERVISOR_EVENT, self._supervisor_events_to_jobs ) @@ -158,3 +166,14 @@ class SupervisorJobs: for sub in self._subscriptions: if sub.matches(job): sub.event_callback(job) + + # If the job is done, pop it from our cache if present after processing is done + if job.done and job.uuid in self._jobs: + del self._jobs[job.uuid] + + @callback + def unload(self) -> None: + """Unregister with dispatcher on config entry unload.""" + if self._dispatcher_disconnect: + self._dispatcher_disconnect() + self._dispatcher_disconnect = None diff --git a/tests/components/hassio/test_jobs.py b/tests/components/hassio/test_jobs.py new file mode 100644 index 00000000000..ea93f3002f1 --- /dev/null +++ b/tests/components/hassio/test_jobs.py @@ -0,0 +1,345 @@ +"""Test supervisor jobs manager.""" + +from collections.abc import Generator +from datetime import datetime +import os +from unittest.mock import AsyncMock, patch +from uuid import uuid4 + +from aiohasupervisor.models import Job, JobsInfo +import pytest + +from homeassistant.components.hassio.const import ADDONS_COORDINATOR +from homeassistant.components.hassio.coordinator import HassioDataUpdateCoordinator +from homeassistant.components.hassio.jobs import JobSubscription +from homeassistant.core import HomeAssistant, callback +from homeassistant.setup import async_setup_component + +from .test_init import MOCK_ENVIRON + +from tests.typing import WebSocketGenerator + + +@pytest.fixture(autouse=True) +def fixture_supervisor_environ() -> Generator[None]: + """Mock os environ for supervisor.""" + with patch.dict(os.environ, MOCK_ENVIRON): + yield + + +@pytest.mark.usefixtures("all_setup_requests") +async def test_job_manager_setup(hass: HomeAssistant, jobs_info: AsyncMock) -> None: + """Test setup of job manager.""" + jobs_info.return_value = JobsInfo( + ignore_conditions=[], + jobs=[ + Job( + name="test_job", + reference=None, + uuid=uuid4(), + progress=0, + stage=None, + done=False, + errors=[], + created=datetime.now(), + extra=None, + child_jobs=[ + Job( + name="test_inner_job", + reference=None, + uuid=uuid4(), + progress=0, + stage=None, + done=False, + errors=[], + created=datetime.now(), + extra=None, + child_jobs=[], + ) + ], + ) + ], + ) + + result = await async_setup_component(hass, "hassio", {}) + assert result + jobs_info.assert_called_once() + + data_coordinator: HassioDataUpdateCoordinator = hass.data[ADDONS_COORDINATOR] + assert len(data_coordinator.jobs.current_jobs) == 2 + assert data_coordinator.jobs.current_jobs[0].name == "test_job" + assert data_coordinator.jobs.current_jobs[1].name == "test_inner_job" + + +@pytest.mark.usefixtures("all_setup_requests") +async def test_disconnect_on_config_entry_reload( + hass: HomeAssistant, jobs_info: AsyncMock +) -> None: + """Test dispatcher subscription disconnects on config entry reload.""" + result = await async_setup_component(hass, "hassio", {}) + assert result + jobs_info.assert_called_once() + + jobs_info.reset_mock() + data_coordinator: HassioDataUpdateCoordinator = hass.data[ADDONS_COORDINATOR] + await hass.config_entries.async_reload(data_coordinator.entry_id) + await hass.async_block_till_done() + jobs_info.assert_called_once() + + +@pytest.mark.usefixtures("all_setup_requests") +async def test_job_manager_ws_updates( + hass: HomeAssistant, jobs_info: AsyncMock, hass_ws_client: WebSocketGenerator +) -> None: + """Test job updates sync from Supervisor WS messages.""" + result = await async_setup_component(hass, "hassio", {}) + assert result + jobs_info.assert_called_once() + + jobs_info.reset_mock() + client = await hass_ws_client(hass) + data_coordinator: HassioDataUpdateCoordinator = hass.data[ADDONS_COORDINATOR] + assert not data_coordinator.jobs.current_jobs + + # Make an example listener + job_data: Job | None = None + + @callback + def mock_subcription_callback(job: Job) -> None: + nonlocal job_data + job_data = job + + subscription = JobSubscription( + mock_subcription_callback, name="test_job", reference="test" + ) + unsubscribe = data_coordinator.jobs.subscribe(subscription) + + # Send start of job update + await client.send_json( + { + "id": 1, + "type": "supervisor/event", + "data": { + "event": "job", + "data": { + "name": "test_job", + "reference": "test", + "uuid": (uuid := uuid4().hex), + "progress": 0, + "stage": None, + "done": False, + "errors": [], + "created": (created := datetime.now().isoformat()), + "extra": None, + }, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + assert job_data.name == "test_job" + assert job_data.reference == "test" + assert job_data.progress == 0 + assert job_data.done is False + # One job in the cache + assert len(data_coordinator.jobs.current_jobs) == 1 + + # Example progress update + await client.send_json( + { + "id": 2, + "type": "supervisor/event", + "data": { + "event": "job", + "data": { + "name": "test_job", + "reference": "test", + "uuid": uuid, + "progress": 50, + "stage": None, + "done": False, + "errors": [], + "created": created, + "extra": None, + }, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + assert job_data.name == "test_job" + assert job_data.reference == "test" + assert job_data.progress == 50 + assert job_data.done is False + # Same job, same number of jobs in cache + assert len(data_coordinator.jobs.current_jobs) == 1 + + # Unrelated job update - name change, subscriber should not receive + await client.send_json( + { + "id": 3, + "type": "supervisor/event", + "data": { + "event": "job", + "data": { + "name": "bad_job", + "reference": "test", + "uuid": uuid4().hex, + "progress": 0, + "stage": None, + "done": False, + "errors": [], + "created": created, + "extra": None, + }, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + assert job_data.name == "test_job" + assert job_data.reference == "test" + # New job, cache increases + assert len(data_coordinator.jobs.current_jobs) == 2 + + # Unrelated job update - reference change, subscriber should not receive + await client.send_json( + { + "id": 4, + "type": "supervisor/event", + "data": { + "event": "job", + "data": { + "name": "test_job", + "reference": "bad", + "uuid": uuid4().hex, + "progress": 0, + "stage": None, + "done": False, + "errors": [], + "created": created, + "extra": None, + }, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + assert job_data.name == "test_job" + assert job_data.reference == "test" + # New job, cache increases + assert len(data_coordinator.jobs.current_jobs) == 3 + + # Unsubscribe mock listener, should not receive final update + unsubscribe() + await client.send_json( + { + "id": 5, + "type": "supervisor/event", + "data": { + "event": "job", + "data": { + "name": "test_job", + "reference": "test", + "uuid": uuid, + "progress": 100, + "stage": None, + "done": True, + "errors": [], + "created": created, + "extra": None, + }, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + assert job_data.name == "test_job" + assert job_data.reference == "test" + assert job_data.progress == 50 + assert job_data.done is False + # Job ended, cache decreases + assert len(data_coordinator.jobs.current_jobs) == 2 + + # REST API should not be used during this sequence + jobs_info.assert_not_called() + + +@pytest.mark.usefixtures("all_setup_requests") +async def test_job_manager_reload_on_supervisor_restart( + hass: HomeAssistant, jobs_info: AsyncMock, hass_ws_client: WebSocketGenerator +) -> None: + """Test job manager reloads cache on supervisor restart.""" + jobs_info.return_value = JobsInfo( + ignore_conditions=[], + jobs=[ + Job( + name="test_job", + reference="test", + uuid=uuid4(), + progress=0, + stage=None, + done=False, + errors=[], + created=datetime.now(), + extra=None, + child_jobs=[], + ) + ], + ) + + result = await async_setup_component(hass, "hassio", {}) + assert result + jobs_info.assert_called_once() + + data_coordinator: HassioDataUpdateCoordinator = hass.data[ADDONS_COORDINATOR] + assert len(data_coordinator.jobs.current_jobs) == 1 + assert data_coordinator.jobs.current_jobs[0].name == "test_job" + + jobs_info.reset_mock() + jobs_info.return_value = JobsInfo(ignore_conditions=[], jobs=[]) + client = await hass_ws_client(hass) + + # Make an example listener + job_data: Job | None = None + + @callback + def mock_subcription_callback(job: Job) -> None: + nonlocal job_data + job_data = job + + subscription = JobSubscription(mock_subcription_callback, name="test_job") + data_coordinator.jobs.subscribe(subscription) + + # Send supervisor restart signal + await client.send_json( + { + "id": 1, + "type": "supervisor/event", + "data": { + "event": "supervisor_update", + "update_key": "supervisor", + "data": {"startup": "complete"}, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + await hass.async_block_till_done() + + # Listener should be told job is done and cache cleared out + jobs_info.assert_called_once() + assert job_data.name == "test_job" + assert job_data.reference == "test" + assert job_data.done is True + assert not data_coordinator.jobs.current_jobs