diff --git a/homeassistant/components/arcam_fmj/__init__.py b/homeassistant/components/arcam_fmj/__init__.py index df088738a64..f389bc55a2b 100644 --- a/homeassistant/components/arcam_fmj/__init__.py +++ b/homeassistant/components/arcam_fmj/__init__.py @@ -2,8 +2,8 @@ import asyncio from asyncio import timeout +from contextlib import AsyncExitStack import logging -from typing import Any from arcam.fmj import ConnectionFailed from arcam.fmj.client import Client @@ -54,36 +54,31 @@ async def _run_client( client = runtime_data.client coordinators = runtime_data.coordinators - def _listen(_: Any) -> None: - for coordinator in coordinators.values(): - coordinator.async_notify_data_updated() - while True: try: - async with timeout(interval): - await client.start() + async with AsyncExitStack() as stack: + async with timeout(interval): + await client.start() + stack.push_async_callback(client.stop) - _LOGGER.debug("Client connected %s", client.host) + _LOGGER.debug("Client connected %s", client.host) - try: - for coordinator in coordinators.values(): - await coordinator.state.start() - - with client.listen(_listen): + try: for coordinator in coordinators.values(): - coordinator.async_notify_connected() - await client.process() - finally: - await client.stop() + await stack.enter_async_context( + coordinator.async_monitor_client() + ) - _LOGGER.debug("Client disconnected %s", client.host) - for coordinator in coordinators.values(): - coordinator.async_notify_disconnected() + await client.process() + finally: + _LOGGER.debug("Client disconnected %s", client.host) except ConnectionFailed: - await asyncio.sleep(interval) + pass except TimeoutError: continue except Exception: _LOGGER.exception("Unexpected exception, aborting arcam client") return + + await asyncio.sleep(interval) diff --git a/homeassistant/components/arcam_fmj/coordinator.py b/homeassistant/components/arcam_fmj/coordinator.py index 83faef37d10..39b3f28fc68 100644 --- a/homeassistant/components/arcam_fmj/coordinator.py +++ b/homeassistant/components/arcam_fmj/coordinator.py @@ -2,11 +2,13 @@ from __future__ import annotations +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from dataclasses import dataclass import logging from arcam.fmj import ConnectionFailed -from arcam.fmj.client import Client +from arcam.fmj.client import AmxDuetResponse, Client, ResponsePacket from arcam.fmj.state import State from homeassistant.config_entries import ConfigEntry @@ -51,7 +53,7 @@ class ArcamFmjCoordinator(DataUpdateCoordinator[None]): ) self.client = client self.state = State(client, zone) - self.last_update_success = False + self.update_in_progress = False name = config_entry.title unique_id = config_entry.unique_id or config_entry.entry_id @@ -74,24 +76,34 @@ class ArcamFmjCoordinator(DataUpdateCoordinator[None]): async def _async_update_data(self) -> None: """Fetch data for manual refresh.""" try: + self.update_in_progress = True await self.state.update() except ConnectionFailed as err: raise UpdateFailed( f"Connection failed during update for zone {self.state.zn}" ) from err + finally: + self.update_in_progress = False @callback - def async_notify_data_updated(self) -> None: - """Notify that new data has been received from the device.""" - self.async_set_updated_data(None) + def _async_notify_packet(self, packet: ResponsePacket | AmxDuetResponse) -> None: + """Packet callback to detect changes to state.""" + if ( + not isinstance(packet, ResponsePacket) + or packet.zn != self.state.zn + or self.update_in_progress + ): + return - @callback - def async_notify_connected(self) -> None: - """Handle client connected.""" - self.hass.async_create_task(self.async_refresh()) - - @callback - def async_notify_disconnected(self) -> None: - """Handle client disconnected.""" - self.last_update_success = False self.async_update_listeners() + + @asynccontextmanager + async def async_monitor_client(self) -> AsyncGenerator[None]: + """Monitor a client and state for changes while connected.""" + async with self.state: + self.hass.async_create_task(self.async_refresh()) + try: + with self.client.listen(self._async_notify_packet): + yield + finally: + self.hass.async_create_task(self.async_refresh()) diff --git a/homeassistant/components/arcam_fmj/entity.py b/homeassistant/components/arcam_fmj/entity.py index 6d635a5f1c5..cf97ef32c38 100644 --- a/homeassistant/components/arcam_fmj/entity.py +++ b/homeassistant/components/arcam_fmj/entity.py @@ -26,3 +26,8 @@ class ArcamFmjEntity(CoordinatorEntity[ArcamFmjCoordinator]): if description is not None: self._attr_unique_id = f"{self._attr_unique_id}-{description.key}" self.entity_description = description + + @property + def available(self) -> bool: + """Return if entity is available.""" + return super().available and self.coordinator.client.connected diff --git a/tests/components/arcam_fmj/conftest.py b/tests/components/arcam_fmj/conftest.py index f11a1c3002f..1fc6e6b607e 100644 --- a/tests/components/arcam_fmj/conftest.py +++ b/tests/components/arcam_fmj/conftest.py @@ -1,15 +1,17 @@ """Tests for the arcam_fmj component.""" -from collections.abc import AsyncGenerator -from unittest.mock import Mock, patch +from asyncio import CancelledError, Queue +from collections.abc import AsyncGenerator, Generator +from contextlib import contextmanager +from unittest.mock import AsyncMock, Mock, patch -from arcam.fmj.client import Client +from arcam.fmj.client import Client, ResponsePacket from arcam.fmj.state import State import pytest from homeassistant.components.arcam_fmj.const import DEFAULT_NAME from homeassistant.const import CONF_HOST, CONF_PORT -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -28,12 +30,50 @@ MOCK_CONFIG_ENTRY = {CONF_HOST: MOCK_HOST, CONF_PORT: MOCK_PORT} @pytest.fixture(name="client") -def client_fixture() -> Mock: +def client_fixture() -> Generator[Mock]: """Get a mocked client.""" client = Mock(Client) client.host = MOCK_HOST client.port = MOCK_PORT - return client + + queue = Queue[BaseException | None]() + listeners = set() + + async def _start(): + client.connected = True + + async def _process(): + result = await queue.get() + client.connected = False + if isinstance(result, BaseException): + raise result + + @contextmanager + def _listen(listener): + listeners.add(listener) + yield client + listeners.remove(listener) + + @callback + def _notify_data_updated(zn=1): + packet = Mock(ResponsePacket) + packet.zn = zn + for listener in listeners: + listener(packet) + + @callback + def _notify_connection(exception: Exception | None = None): + queue.put_nowait(exception) + + client.start.side_effect = _start + client.process.side_effect = _process + client.listen.side_effect = _listen + client.notify_data_updated = _notify_data_updated + client.notify_connection = _notify_connection + + yield client + + queue.put_nowait(CancelledError()) @pytest.fixture(name="state_1") @@ -52,6 +92,8 @@ def state_1_fixture(client: Mock) -> State: state.get_mute.return_value = None state.get_decode_modes.return_value = [] state.get_decode_mode.return_value = None + state.__aenter__ = AsyncMock() + state.__aexit__ = AsyncMock() return state @@ -71,6 +113,8 @@ def state_2_fixture(client: Mock) -> State: state.get_mute.return_value = None state.get_decode_modes.return_value = [] state.get_decode_mode.return_value = None + state.__aenter__ = AsyncMock() + state.__aexit__ = AsyncMock() return state @@ -104,18 +148,6 @@ async def player_setup_fixture( return state_2 raise ValueError(f"Unknown player zone: {zone}") - async def _mock_run_client(hass: HomeAssistant, runtime_data, interval): - coordinators = runtime_data.coordinators - - def _notify_data_updated() -> None: - for coordinator in coordinators.values(): - coordinator.async_notify_data_updated() - - client.notify_data_updated = _notify_data_updated - - for coordinator in coordinators.values(): - coordinator.async_notify_connected() - await async_setup_component(hass, "homeassistant", {}) with ( @@ -124,10 +156,6 @@ async def player_setup_fixture( "homeassistant.components.arcam_fmj.coordinator.State", side_effect=state_mock, ), - patch( - "homeassistant.components.arcam_fmj._run_client", - side_effect=_mock_run_client, - ), ): assert await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() diff --git a/tests/components/arcam_fmj/test_media_player.py b/tests/components/arcam_fmj/test_media_player.py index b1a7468fb46..d14fb8fc2f6 100644 --- a/tests/components/arcam_fmj/test_media_player.py +++ b/tests/components/arcam_fmj/test_media_player.py @@ -33,7 +33,7 @@ from homeassistant.components.media_player import ( SERVICE_VOLUME_UP, MediaType, ) -from homeassistant.const import ATTR_ENTITY_ID, Platform +from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, Platform from homeassistant.core import HomeAssistant, State as CoreState from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er @@ -62,6 +62,21 @@ async def test_setup( await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id) +@pytest.mark.usefixtures("player_setup") +async def test_disconnect(hass: HomeAssistant, client: Mock) -> None: + """Test a disconnection is detected.""" + data = hass.states.get(MOCK_ENTITY_ID) + assert data + assert data.state != STATE_UNAVAILABLE + + client.notify_connection(ConnectionFailed()) + await hass.async_block_till_done() + + data = hass.states.get(MOCK_ENTITY_ID) + assert data + assert data.state == STATE_UNAVAILABLE + + async def update(hass: HomeAssistant, client: Mock, entity_id: str) -> CoreState: """Force a update of player and return current state data.""" client.notify_data_updated()