mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 00:20:30 +01:00
Rework patching and handling of client runner in arcam (#165747)
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
from asyncio import timeout
|
||||
from contextlib import AsyncExitStack
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from arcam.fmj import ConnectionFailed
|
||||
from arcam.fmj.client import Client
|
||||
@@ -54,36 +54,31 @@ async def _run_client(
|
||||
client = runtime_data.client
|
||||
coordinators = runtime_data.coordinators
|
||||
|
||||
def _listen(_: Any) -> None:
|
||||
for coordinator in coordinators.values():
|
||||
coordinator.async_notify_data_updated()
|
||||
|
||||
while True:
|
||||
try:
|
||||
async with timeout(interval):
|
||||
await client.start()
|
||||
async with AsyncExitStack() as stack:
|
||||
async with timeout(interval):
|
||||
await client.start()
|
||||
stack.push_async_callback(client.stop)
|
||||
|
||||
_LOGGER.debug("Client connected %s", client.host)
|
||||
_LOGGER.debug("Client connected %s", client.host)
|
||||
|
||||
try:
|
||||
for coordinator in coordinators.values():
|
||||
await coordinator.state.start()
|
||||
|
||||
with client.listen(_listen):
|
||||
try:
|
||||
for coordinator in coordinators.values():
|
||||
coordinator.async_notify_connected()
|
||||
await client.process()
|
||||
finally:
|
||||
await client.stop()
|
||||
await stack.enter_async_context(
|
||||
coordinator.async_monitor_client()
|
||||
)
|
||||
|
||||
_LOGGER.debug("Client disconnected %s", client.host)
|
||||
for coordinator in coordinators.values():
|
||||
coordinator.async_notify_disconnected()
|
||||
await client.process()
|
||||
finally:
|
||||
_LOGGER.debug("Client disconnected %s", client.host)
|
||||
|
||||
except ConnectionFailed:
|
||||
await asyncio.sleep(interval)
|
||||
pass
|
||||
except TimeoutError:
|
||||
continue
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception, aborting arcam client")
|
||||
return
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
from arcam.fmj import ConnectionFailed
|
||||
from arcam.fmj.client import Client
|
||||
from arcam.fmj.client import AmxDuetResponse, Client, ResponsePacket
|
||||
from arcam.fmj.state import State
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -51,7 +53,7 @@ class ArcamFmjCoordinator(DataUpdateCoordinator[None]):
|
||||
)
|
||||
self.client = client
|
||||
self.state = State(client, zone)
|
||||
self.last_update_success = False
|
||||
self.update_in_progress = False
|
||||
|
||||
name = config_entry.title
|
||||
unique_id = config_entry.unique_id or config_entry.entry_id
|
||||
@@ -74,24 +76,34 @@ class ArcamFmjCoordinator(DataUpdateCoordinator[None]):
|
||||
async def _async_update_data(self) -> None:
|
||||
"""Fetch data for manual refresh."""
|
||||
try:
|
||||
self.update_in_progress = True
|
||||
await self.state.update()
|
||||
except ConnectionFailed as err:
|
||||
raise UpdateFailed(
|
||||
f"Connection failed during update for zone {self.state.zn}"
|
||||
) from err
|
||||
finally:
|
||||
self.update_in_progress = False
|
||||
|
||||
@callback
|
||||
def async_notify_data_updated(self) -> None:
|
||||
"""Notify that new data has been received from the device."""
|
||||
self.async_set_updated_data(None)
|
||||
def _async_notify_packet(self, packet: ResponsePacket | AmxDuetResponse) -> None:
|
||||
"""Packet callback to detect changes to state."""
|
||||
if (
|
||||
not isinstance(packet, ResponsePacket)
|
||||
or packet.zn != self.state.zn
|
||||
or self.update_in_progress
|
||||
):
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_notify_connected(self) -> None:
|
||||
"""Handle client connected."""
|
||||
self.hass.async_create_task(self.async_refresh())
|
||||
|
||||
@callback
|
||||
def async_notify_disconnected(self) -> None:
|
||||
"""Handle client disconnected."""
|
||||
self.last_update_success = False
|
||||
self.async_update_listeners()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_monitor_client(self) -> AsyncGenerator[None]:
|
||||
"""Monitor a client and state for changes while connected."""
|
||||
async with self.state:
|
||||
self.hass.async_create_task(self.async_refresh())
|
||||
try:
|
||||
with self.client.listen(self._async_notify_packet):
|
||||
yield
|
||||
finally:
|
||||
self.hass.async_create_task(self.async_refresh())
|
||||
|
||||
@@ -26,3 +26,8 @@ class ArcamFmjEntity(CoordinatorEntity[ArcamFmjCoordinator]):
|
||||
if description is not None:
|
||||
self._attr_unique_id = f"{self._attr_unique_id}-{description.key}"
|
||||
self.entity_description = description
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return if entity is available."""
|
||||
return super().available and self.coordinator.client.connected
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
"""Tests for the arcam_fmj component."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import Mock, patch
|
||||
from asyncio import CancelledError, Queue
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from arcam.fmj.client import Client
|
||||
from arcam.fmj.client import Client, ResponsePacket
|
||||
from arcam.fmj.state import State
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.arcam_fmj.const import DEFAULT_NAME
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
@@ -28,12 +30,50 @@ MOCK_CONFIG_ENTRY = {CONF_HOST: MOCK_HOST, CONF_PORT: MOCK_PORT}
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
def client_fixture() -> Mock:
|
||||
def client_fixture() -> Generator[Mock]:
|
||||
"""Get a mocked client."""
|
||||
client = Mock(Client)
|
||||
client.host = MOCK_HOST
|
||||
client.port = MOCK_PORT
|
||||
return client
|
||||
|
||||
queue = Queue[BaseException | None]()
|
||||
listeners = set()
|
||||
|
||||
async def _start():
|
||||
client.connected = True
|
||||
|
||||
async def _process():
|
||||
result = await queue.get()
|
||||
client.connected = False
|
||||
if isinstance(result, BaseException):
|
||||
raise result
|
||||
|
||||
@contextmanager
|
||||
def _listen(listener):
|
||||
listeners.add(listener)
|
||||
yield client
|
||||
listeners.remove(listener)
|
||||
|
||||
@callback
|
||||
def _notify_data_updated(zn=1):
|
||||
packet = Mock(ResponsePacket)
|
||||
packet.zn = zn
|
||||
for listener in listeners:
|
||||
listener(packet)
|
||||
|
||||
@callback
|
||||
def _notify_connection(exception: Exception | None = None):
|
||||
queue.put_nowait(exception)
|
||||
|
||||
client.start.side_effect = _start
|
||||
client.process.side_effect = _process
|
||||
client.listen.side_effect = _listen
|
||||
client.notify_data_updated = _notify_data_updated
|
||||
client.notify_connection = _notify_connection
|
||||
|
||||
yield client
|
||||
|
||||
queue.put_nowait(CancelledError())
|
||||
|
||||
|
||||
@pytest.fixture(name="state_1")
|
||||
@@ -52,6 +92,8 @@ def state_1_fixture(client: Mock) -> State:
|
||||
state.get_mute.return_value = None
|
||||
state.get_decode_modes.return_value = []
|
||||
state.get_decode_mode.return_value = None
|
||||
state.__aenter__ = AsyncMock()
|
||||
state.__aexit__ = AsyncMock()
|
||||
return state
|
||||
|
||||
|
||||
@@ -71,6 +113,8 @@ def state_2_fixture(client: Mock) -> State:
|
||||
state.get_mute.return_value = None
|
||||
state.get_decode_modes.return_value = []
|
||||
state.get_decode_mode.return_value = None
|
||||
state.__aenter__ = AsyncMock()
|
||||
state.__aexit__ = AsyncMock()
|
||||
return state
|
||||
|
||||
|
||||
@@ -104,18 +148,6 @@ async def player_setup_fixture(
|
||||
return state_2
|
||||
raise ValueError(f"Unknown player zone: {zone}")
|
||||
|
||||
async def _mock_run_client(hass: HomeAssistant, runtime_data, interval):
|
||||
coordinators = runtime_data.coordinators
|
||||
|
||||
def _notify_data_updated() -> None:
|
||||
for coordinator in coordinators.values():
|
||||
coordinator.async_notify_data_updated()
|
||||
|
||||
client.notify_data_updated = _notify_data_updated
|
||||
|
||||
for coordinator in coordinators.values():
|
||||
coordinator.async_notify_connected()
|
||||
|
||||
await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
with (
|
||||
@@ -124,10 +156,6 @@ async def player_setup_fixture(
|
||||
"homeassistant.components.arcam_fmj.coordinator.State",
|
||||
side_effect=state_mock,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.arcam_fmj._run_client",
|
||||
side_effect=_mock_run_client,
|
||||
),
|
||||
):
|
||||
assert await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
@@ -33,7 +33,7 @@ from homeassistant.components.media_player import (
|
||||
SERVICE_VOLUME_UP,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID, Platform
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, Platform
|
||||
from homeassistant.core import HomeAssistant, State as CoreState
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
@@ -62,6 +62,21 @@ async def test_setup(
|
||||
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("player_setup")
|
||||
async def test_disconnect(hass: HomeAssistant, client: Mock) -> None:
|
||||
"""Test a disconnection is detected."""
|
||||
data = hass.states.get(MOCK_ENTITY_ID)
|
||||
assert data
|
||||
assert data.state != STATE_UNAVAILABLE
|
||||
|
||||
client.notify_connection(ConnectionFailed())
|
||||
await hass.async_block_till_done()
|
||||
|
||||
data = hass.states.get(MOCK_ENTITY_ID)
|
||||
assert data
|
||||
assert data.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def update(hass: HomeAssistant, client: Mock, entity_id: str) -> CoreState:
|
||||
"""Force a update of player and return current state data."""
|
||||
client.notify_data_updated()
|
||||
|
||||
Reference in New Issue
Block a user