mirror of
https://github.com/home-assistant/core.git
synced 2026-05-08 17:49:37 +01:00
Restart SimpliSafe websocket after request failures (#160974)
Co-authored-by: Joostlek <joostlek@outlook.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from simplipy import API
|
||||
from simplipy.errors import (
|
||||
EndpointUnavailableError,
|
||||
InvalidCredentialsError,
|
||||
RequestError,
|
||||
SimplipyError,
|
||||
WebsocketError,
|
||||
)
|
||||
@@ -46,10 +47,9 @@ from homeassistant.const import (
|
||||
CONF_CODE,
|
||||
CONF_TOKEN,
|
||||
CONF_USERNAME,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import CoreState, Event, HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.core import CoreState, HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryAuthFailed,
|
||||
ConfigEntryNotReady,
|
||||
@@ -103,6 +103,7 @@ DEFAULT_SCAN_INTERVAL = timedelta(seconds=30)
|
||||
|
||||
WEBSOCKET_RECONNECT_RETRIES = 3
|
||||
WEBSOCKET_RETRY_DELAY = 2
|
||||
WEBSOCKET_LOOP_TASK_NAME = "simplisafe websocket task"
|
||||
|
||||
EVENT_SIMPLISAFE_EVENT = "SIMPLISAFE_EVENT"
|
||||
EVENT_SIMPLISAFE_NOTIFICATION = "SIMPLISAFE_NOTIFICATION"
|
||||
@@ -420,8 +421,7 @@ class SimpliSafe:
|
||||
self._api = api
|
||||
self._hass = hass
|
||||
self._system_notifications: dict[int, set[SystemNotification]] = {}
|
||||
self._websocket_reconnect_retries: int = 0
|
||||
self._websocket_reconnect_task: asyncio.Task | None = None
|
||||
self._websocket_task: asyncio.Task | None = None
|
||||
self.entry = entry
|
||||
self.initial_event_to_use: dict[int, dict[str, Any]] = {}
|
||||
self.subscription_data: dict[int, Any] = api.subscription_data
|
||||
@@ -467,53 +467,69 @@ class SimpliSafe:
|
||||
|
||||
self._system_notifications[system.system_id] = latest_notifications
|
||||
|
||||
async def _async_start_websocket_loop(self) -> None:
|
||||
"""Start a websocket reconnection loop."""
|
||||
assert self._api.websocket
|
||||
@callback
|
||||
def _async_start_websocket_if_needed(self) -> None:
|
||||
"""Start the websocket loop task if it isn't already running."""
|
||||
task = self._websocket_task
|
||||
|
||||
self._websocket_reconnect_retries += 1
|
||||
|
||||
try:
|
||||
await self._api.websocket.async_connect()
|
||||
await self._api.websocket.async_listen()
|
||||
except asyncio.CancelledError:
|
||||
LOGGER.debug("Request to cancel websocket loop received")
|
||||
raise
|
||||
except WebsocketError as err:
|
||||
LOGGER.error("Failed to connect to websocket: %s", err)
|
||||
except Exception as err: # noqa: BLE001
|
||||
LOGGER.error("Unknown exception while connecting to websocket: %s", err)
|
||||
else:
|
||||
self._websocket_reconnect_retries = 0
|
||||
|
||||
if self._websocket_reconnect_retries >= WEBSOCKET_RECONNECT_RETRIES:
|
||||
LOGGER.error("Max websocket connection retries exceeded")
|
||||
if task and not task.done():
|
||||
return
|
||||
|
||||
delay = WEBSOCKET_RETRY_DELAY * (2 ** (self._websocket_reconnect_retries - 1))
|
||||
LOGGER.info(
|
||||
"Retrying websocket connection in %s seconds (attempt %s/%s)",
|
||||
delay,
|
||||
self._websocket_reconnect_retries,
|
||||
WEBSOCKET_RECONNECT_RETRIES,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
self._websocket_reconnect_task = self._hass.async_create_task(
|
||||
self._async_start_websocket_loop()
|
||||
LOGGER.debug("Starting websocket loop task")
|
||||
|
||||
self._websocket_task = self.entry.async_create_background_task(
|
||||
self._hass, self._async_websocket_loop(), WEBSOCKET_LOOP_TASK_NAME
|
||||
)
|
||||
|
||||
async def _async_websocket_loop(self) -> None:
|
||||
assert self._api.websocket
|
||||
|
||||
retries = 0
|
||||
while True:
|
||||
try:
|
||||
await self._api.websocket.async_connect()
|
||||
await self._api.websocket.async_listen()
|
||||
except asyncio.CancelledError:
|
||||
await self._api.websocket.async_disconnect()
|
||||
raise
|
||||
except WebsocketError as err:
|
||||
retries += 1
|
||||
delay = WEBSOCKET_RETRY_DELAY * (2 ** (retries - 1))
|
||||
LOGGER.debug(
|
||||
"Websocket error (%s/%s): %s; retrying in %s seconds",
|
||||
retries,
|
||||
WEBSOCKET_RECONNECT_RETRIES,
|
||||
err,
|
||||
delay,
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
if retries >= WEBSOCKET_RECONNECT_RETRIES:
|
||||
LOGGER.error(
|
||||
"Websocket connection failed, task exiting (%s/%s): %s",
|
||||
retries,
|
||||
WEBSOCKET_RECONNECT_RETRIES,
|
||||
err,
|
||||
)
|
||||
return
|
||||
except Exception as err: # noqa: BLE001
|
||||
# unexpected errors → log and stop
|
||||
LOGGER.exception("Unexpected error in websocket loop: %s", err)
|
||||
return
|
||||
|
||||
async def _async_cancel_websocket_loop(self) -> None:
|
||||
"""Stop any existing websocket reconnection loop."""
|
||||
if self._websocket_reconnect_task:
|
||||
self._websocket_reconnect_task.cancel()
|
||||
try:
|
||||
await self._websocket_reconnect_task
|
||||
except asyncio.CancelledError:
|
||||
LOGGER.debug("Websocket reconnection task successfully canceled")
|
||||
self._websocket_reconnect_task = None
|
||||
"""Cancel the websocket loop task, if running."""
|
||||
task = self._websocket_task
|
||||
if not task:
|
||||
return
|
||||
|
||||
assert self._api.websocket
|
||||
await self._api.websocket.async_disconnect()
|
||||
self._websocket_task = None
|
||||
task.cancel()
|
||||
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
LOGGER.debug("Websocket loop task cancelled")
|
||||
|
||||
@callback
|
||||
def _async_websocket_on_event(self, event: WebsocketEvent) -> None:
|
||||
@@ -553,20 +569,7 @@ class SimpliSafe:
|
||||
assert self._api.websocket
|
||||
|
||||
self._api.websocket.add_event_callback(self._async_websocket_on_event)
|
||||
self._websocket_reconnect_task = asyncio.create_task(
|
||||
self._async_start_websocket_loop()
|
||||
)
|
||||
|
||||
async def async_websocket_disconnect_listener(_: Event) -> None:
|
||||
"""Define an event handler to disconnect from the websocket."""
|
||||
assert self._api.websocket
|
||||
await self._async_cancel_websocket_loop()
|
||||
|
||||
self.entry.async_on_unload(
|
||||
self._hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, async_websocket_disconnect_listener
|
||||
)
|
||||
)
|
||||
self._async_start_websocket_if_needed()
|
||||
|
||||
self.systems = await self._api.async_get_systems()
|
||||
for system in self.systems.values():
|
||||
@@ -610,9 +613,7 @@ class SimpliSafe:
|
||||
# Open a new websocket connection with the fresh token:
|
||||
assert self._api.websocket
|
||||
await self._async_cancel_websocket_loop()
|
||||
self._websocket_reconnect_task = self._hass.async_create_task(
|
||||
self._async_start_websocket_loop()
|
||||
)
|
||||
self._async_start_websocket_if_needed()
|
||||
|
||||
self.entry.async_on_unload(
|
||||
self._api.add_refresh_token_callback(async_handle_refresh_token)
|
||||
@@ -625,22 +626,37 @@ class SimpliSafe:
|
||||
"""Get updated data from SimpliSafe."""
|
||||
|
||||
async def async_update_system(system: SystemType) -> None:
|
||||
"""Update a system."""
|
||||
"""Update a single system and process notifications."""
|
||||
await system.async_update(cached=system.version != 3)
|
||||
self._async_process_new_notifications(system)
|
||||
|
||||
tasks = [async_update_system(system) for system in self.systems.values()]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, InvalidCredentialsError):
|
||||
raise ConfigEntryAuthFailed("Invalid credentials") from result
|
||||
|
||||
if isinstance(result, EndpointUnavailableError):
|
||||
# In case the user attempts an action not allowed in their current plan,
|
||||
# we merely log that message at INFO level (so the user is aware,
|
||||
# but not spammed with ERROR messages that they cannot change):
|
||||
LOGGER.debug(result)
|
||||
|
||||
if isinstance(result, SimplipyError):
|
||||
raise UpdateFailed(f"SimpliSafe error while updating: {result}")
|
||||
try:
|
||||
# Gather all system updates; exceptions will propagate
|
||||
await asyncio.gather(*tasks)
|
||||
except InvalidCredentialsError as err:
|
||||
# Stop websocket immediately on auth failure
|
||||
if self._websocket_task:
|
||||
LOGGER.debug("Cancelling websocket loop due to invalid credentials")
|
||||
await self._async_cancel_websocket_loop()
|
||||
# Signal HA that credentials are invalid; user intervention is required
|
||||
raise ConfigEntryAuthFailed("Invalid credentials") from err
|
||||
except RequestError as err:
|
||||
# Cloud-level request errors: wrap aiohttp errors
|
||||
if self._websocket_task:
|
||||
LOGGER.debug("Cancelling websocket loop due to request error")
|
||||
await self._async_cancel_websocket_loop()
|
||||
raise UpdateFailed(
|
||||
f"Request error while updating all systems: {err}"
|
||||
) from err
|
||||
except EndpointUnavailableError as err:
|
||||
# Currently not raised by the API; included for future-proofing.
|
||||
# Informational per-system (e.g., user plan restrictions)
|
||||
LOGGER.debug("Endpoint unavailable: %s", err)
|
||||
except SimplipyError as err:
|
||||
# Any other SimplipyError not caught per-system
|
||||
raise UpdateFailed(f"SimpliSafe error while updating: {err}") from err
|
||||
else:
|
||||
# Successful update, try to restart websocket if necessary
|
||||
self._async_start_websocket_if_needed()
|
||||
|
||||
@@ -108,7 +108,7 @@ class SimpliSafeLock(SimpliSafeEntity, LockEntity):
|
||||
"""Update the entity when new data comes from the websocket."""
|
||||
assert event.event_type
|
||||
|
||||
if state := STATE_MAP_FROM_WEBSOCKET_EVENT.get(event.event_type) is not None:
|
||||
if (state := STATE_MAP_FROM_WEBSOCKET_EVENT.get(event.event_type)) is not None:
|
||||
self._attr_is_locked = state
|
||||
self.async_reset_error_count()
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Define test fixtures for SimpliSafe."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -87,7 +86,7 @@ def data_settings_fixture() -> JsonObjectType:
|
||||
def data_subscription_fixture() -> JsonObjectType:
|
||||
"""Define subscription data."""
|
||||
data = load_json_object_fixture("subscription_data.json", "simplisafe")
|
||||
return {SYSTEM_ID: data}
|
||||
return {SYSTEM_ID: data} # type: ignore[return-value]
|
||||
|
||||
|
||||
@pytest.fixture(name="reauth_config")
|
||||
@@ -98,11 +97,9 @@ def reauth_config_fixture() -> dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(name="setup_simplisafe")
|
||||
async def setup_simplisafe_fixture(
|
||||
hass: HomeAssistant, api: Mock, config: dict[str, str]
|
||||
) -> AsyncGenerator[None]:
|
||||
"""Define a fixture to set up SimpliSafe."""
|
||||
@pytest.fixture(name="patch_simplisafe_api")
|
||||
def patch_simplisafe_api_fixture(api: Mock, websocket: Mock):
|
||||
"""Patch the SimpliSafe API creation methods."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.config_flow.API.async_from_auth",
|
||||
@@ -117,18 +114,22 @@ async def setup_simplisafe_fixture(
|
||||
return_value=api,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.SimpliSafe._async_start_websocket_loop"
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.PLATFORMS",
|
||||
[],
|
||||
"homeassistant.components.simplisafe.SimpliSafe._async_start_websocket_if_needed",
|
||||
),
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.async_block_till_done()
|
||||
api.websocket = websocket
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(name="setup_simplisafe")
|
||||
async def setup_simplisafe_fixture(
|
||||
hass: HomeAssistant, api: Mock, config: dict[str, str], patch_simplisafe_api
|
||||
) -> None:
|
||||
"""Define a fixture to set up SimpliSafe for config flow tests."""
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture(name="sms_config")
|
||||
def sms_config_fixture() -> dict[str, str]:
|
||||
"""Define a SMS-based two-factor authentication config."""
|
||||
@@ -150,6 +151,7 @@ def system_v3_fixture(
|
||||
system.sensor_data = data_sensor
|
||||
system.settings_data = data_settings
|
||||
system.generate_device_objects()
|
||||
system.async_update = AsyncMock(return_value=None)
|
||||
return system
|
||||
|
||||
|
||||
|
||||
@@ -1,50 +1,134 @@
|
||||
"""Define tests for SimpliSafe setup."""
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from homeassistant.components.simplisafe import DOMAIN
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
from simplipy.errors import (
|
||||
EndpointUnavailableError,
|
||||
InvalidCredentialsError,
|
||||
RequestError,
|
||||
SimplipyError,
|
||||
)
|
||||
from simplipy.websocket import WebsocketEvent
|
||||
|
||||
from homeassistant.components.simplisafe import DEFAULT_SCAN_INTERVAL, DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_REAUTH
|
||||
from homeassistant.const import STATE_UNAVAILABLE
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, async_fire_time_changed
|
||||
|
||||
|
||||
async def test_base_station_migration(
|
||||
hass: HomeAssistant, device_registry: dr.DeviceRegistry, api, config, config_entry
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
api: Mock,
|
||||
config: dict[str, str],
|
||||
config_entry: MockConfigEntry,
|
||||
patch_simplisafe_api,
|
||||
) -> None:
|
||||
"""Test that errors are shown when duplicates are added."""
|
||||
old_identifers = (DOMAIN, 12345)
|
||||
new_identifiers = (DOMAIN, "12345")
|
||||
"""Test that old integer-based device identifiers are migrated to strings."""
|
||||
old_identifiers = {(DOMAIN, 12345)}
|
||||
new_identifiers = {(DOMAIN, "12345")}
|
||||
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={old_identifers},
|
||||
identifiers=old_identifiers,
|
||||
manufacturer="SimpliSafe",
|
||||
name="old",
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.config_flow.API.async_from_auth",
|
||||
return_value=api,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.API.async_from_auth",
|
||||
return_value=api,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.API.async_from_refresh_token",
|
||||
return_value=api,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.SimpliSafe._async_start_websocket_loop"
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.simplisafe.PLATFORMS",
|
||||
[],
|
||||
),
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.async_block_till_done()
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert device_registry.async_get_device(identifiers={old_identifers}) is None
|
||||
assert device_registry.async_get_device(identifiers={new_identifiers}) is not None
|
||||
assert device_registry.async_get_device(identifiers=old_identifiers) is None
|
||||
assert device_registry.async_get_device(identifiers=new_identifiers) is not None
|
||||
|
||||
|
||||
async def test_coordinator_update_triggers_reauth_on_invalid_credentials(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
patch_simplisafe_api,
|
||||
system_v3,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test that InvalidCredentialsError triggers a reauth flow."""
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
system_v3.async_update = AsyncMock(side_effect=InvalidCredentialsError("fail"))
|
||||
|
||||
freezer.tick(DEFAULT_SCAN_INTERVAL)
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
flow = flows[0]
|
||||
assert flow.get("context", {}).get("source") == SOURCE_REAUTH
|
||||
assert flow.get("context", {}).get("entry_id") == config_entry.entry_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exc",
|
||||
[RequestError, EndpointUnavailableError, SimplipyError],
|
||||
)
|
||||
async def test_coordinator_update_failure_keeps_entity_available(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
patch_simplisafe_api,
|
||||
system_v3,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
exc: type[SimplipyError],
|
||||
) -> None:
|
||||
"""Test that a single coordinator failure does not immediately mark entities unavailable."""
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("lock.front_door_lock").state != STATE_UNAVAILABLE
|
||||
|
||||
system_v3.async_update = AsyncMock(side_effect=exc("fail"))
|
||||
|
||||
# Trigger one coordinator failure: error_count goes from 0 to 1, below threshold.
|
||||
freezer.tick(DEFAULT_SCAN_INTERVAL)
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
assert hass.states.get("lock.front_door_lock").state != STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_websocket_event_updates_entity_state(
|
||||
hass: HomeAssistant,
|
||||
config_entry: MockConfigEntry,
|
||||
patch_simplisafe_api,
|
||||
websocket: Mock,
|
||||
) -> None:
|
||||
"""Test that a push update from the websocket changes entity state."""
|
||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Retrieve the event callback that was registered with the mock websocket.
|
||||
assert websocket.add_event_callback.called
|
||||
event_callback = websocket.add_event_callback.call_args[0][0]
|
||||
|
||||
assert hass.states.get("lock.front_door_lock").state == "locked"
|
||||
|
||||
# Fire an "unlock" websocket event for the test lock (system_id=12345, serial="987").
|
||||
# CID 9700 maps to EVENT_LOCK_UNLOCKED in the simplipy event mapping.
|
||||
event_callback(
|
||||
WebsocketEvent(
|
||||
event_cid=9700,
|
||||
info="Lock unlocked",
|
||||
system_id=12345,
|
||||
_raw_timestamp=0,
|
||||
_video=None,
|
||||
_vid=None,
|
||||
sensor_serial="987",
|
||||
)
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get("lock.front_door_lock").state == "unlocked"
|
||||
|
||||
Reference in New Issue
Block a user