diff --git a/homeassistant/components/simplisafe/__init__.py b/homeassistant/components/simplisafe/__init__.py index 8e964e0c776..d0238323568 100644 --- a/homeassistant/components/simplisafe/__init__.py +++ b/homeassistant/components/simplisafe/__init__.py @@ -11,6 +11,7 @@ from simplipy import API from simplipy.errors import ( EndpointUnavailableError, InvalidCredentialsError, + RequestError, SimplipyError, WebsocketError, ) @@ -46,10 +47,9 @@ from homeassistant.const import ( CONF_CODE, CONF_TOKEN, CONF_USERNAME, - EVENT_HOMEASSISTANT_STOP, Platform, ) -from homeassistant.core import CoreState, Event, HomeAssistant, ServiceCall, callback +from homeassistant.core import CoreState, HomeAssistant, ServiceCall, callback from homeassistant.exceptions import ( ConfigEntryAuthFailed, ConfigEntryNotReady, @@ -103,6 +103,7 @@ DEFAULT_SCAN_INTERVAL = timedelta(seconds=30) WEBSOCKET_RECONNECT_RETRIES = 3 WEBSOCKET_RETRY_DELAY = 2 +WEBSOCKET_LOOP_TASK_NAME = "simplisafe websocket task" EVENT_SIMPLISAFE_EVENT = "SIMPLISAFE_EVENT" EVENT_SIMPLISAFE_NOTIFICATION = "SIMPLISAFE_NOTIFICATION" @@ -420,8 +421,7 @@ class SimpliSafe: self._api = api self._hass = hass self._system_notifications: dict[int, set[SystemNotification]] = {} - self._websocket_reconnect_retries: int = 0 - self._websocket_reconnect_task: asyncio.Task | None = None + self._websocket_task: asyncio.Task | None = None self.entry = entry self.initial_event_to_use: dict[int, dict[str, Any]] = {} self.subscription_data: dict[int, Any] = api.subscription_data @@ -467,53 +467,69 @@ class SimpliSafe: self._system_notifications[system.system_id] = latest_notifications - async def _async_start_websocket_loop(self) -> None: - """Start a websocket reconnection loop.""" - assert self._api.websocket + @callback + def _async_start_websocket_if_needed(self) -> None: + """Start the websocket loop task if it isn't already running.""" + task = self._websocket_task - self._websocket_reconnect_retries += 1 - - try: - await self._api.websocket.async_connect() - await self._api.websocket.async_listen() - except asyncio.CancelledError: - LOGGER.debug("Request to cancel websocket loop received") - raise - except WebsocketError as err: - LOGGER.error("Failed to connect to websocket: %s", err) - except Exception as err: # noqa: BLE001 - LOGGER.error("Unknown exception while connecting to websocket: %s", err) - else: - self._websocket_reconnect_retries = 0 - - if self._websocket_reconnect_retries >= WEBSOCKET_RECONNECT_RETRIES: - LOGGER.error("Max websocket connection retries exceeded") + if task and not task.done(): return - delay = WEBSOCKET_RETRY_DELAY * (2 ** (self._websocket_reconnect_retries - 1)) - LOGGER.info( - "Retrying websocket connection in %s seconds (attempt %s/%s)", - delay, - self._websocket_reconnect_retries, - WEBSOCKET_RECONNECT_RETRIES, - ) - await asyncio.sleep(delay) - self._websocket_reconnect_task = self._hass.async_create_task( - self._async_start_websocket_loop() + LOGGER.debug("Starting websocket loop task") + + self._websocket_task = self.entry.async_create_background_task( + self._hass, self._async_websocket_loop(), WEBSOCKET_LOOP_TASK_NAME ) + async def _async_websocket_loop(self) -> None: + assert self._api.websocket + + retries = 0 + while True: + try: + await self._api.websocket.async_connect() + await self._api.websocket.async_listen() + except asyncio.CancelledError: + await self._api.websocket.async_disconnect() + raise + except WebsocketError as err: + retries += 1 + delay = WEBSOCKET_RETRY_DELAY * (2 ** (retries - 1)) + LOGGER.debug( + "Websocket error (%s/%s): %s; retrying in %s seconds", + retries, + WEBSOCKET_RECONNECT_RETRIES, + err, + delay, + ) + + await asyncio.sleep(delay) + if retries >= WEBSOCKET_RECONNECT_RETRIES: + LOGGER.error( + "Websocket connection failed, task exiting (%s/%s): %s", + retries, + WEBSOCKET_RECONNECT_RETRIES, + err, + ) + return + except Exception as err: # noqa: BLE001 + # unexpected errors → log and stop + LOGGER.exception("Unexpected error in websocket loop: %s", err) + return + async def _async_cancel_websocket_loop(self) -> None: - """Stop any existing websocket reconnection loop.""" - if self._websocket_reconnect_task: - self._websocket_reconnect_task.cancel() - try: - await self._websocket_reconnect_task - except asyncio.CancelledError: - LOGGER.debug("Websocket reconnection task successfully canceled") - self._websocket_reconnect_task = None + """Cancel the websocket loop task, if running.""" + task = self._websocket_task + if not task: + return - assert self._api.websocket - await self._api.websocket.async_disconnect() + self._websocket_task = None + task.cancel() + + try: + await task + except asyncio.CancelledError: + LOGGER.debug("Websocket loop task cancelled") @callback def _async_websocket_on_event(self, event: WebsocketEvent) -> None: @@ -553,20 +569,7 @@ class SimpliSafe: assert self._api.websocket self._api.websocket.add_event_callback(self._async_websocket_on_event) - self._websocket_reconnect_task = asyncio.create_task( - self._async_start_websocket_loop() - ) - - async def async_websocket_disconnect_listener(_: Event) -> None: - """Define an event handler to disconnect from the websocket.""" - assert self._api.websocket - await self._async_cancel_websocket_loop() - - self.entry.async_on_unload( - self._hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STOP, async_websocket_disconnect_listener - ) - ) + self._async_start_websocket_if_needed() self.systems = await self._api.async_get_systems() for system in self.systems.values(): @@ -610,9 +613,7 @@ class SimpliSafe: # Open a new websocket connection with the fresh token: assert self._api.websocket await self._async_cancel_websocket_loop() - self._websocket_reconnect_task = self._hass.async_create_task( - self._async_start_websocket_loop() - ) + self._async_start_websocket_if_needed() self.entry.async_on_unload( self._api.add_refresh_token_callback(async_handle_refresh_token) @@ -625,22 +626,37 @@ class SimpliSafe: """Get updated data from SimpliSafe.""" async def async_update_system(system: SystemType) -> None: - """Update a system.""" + """Update a single system and process notifications.""" await system.async_update(cached=system.version != 3) self._async_process_new_notifications(system) tasks = [async_update_system(system) for system in self.systems.values()] - results = await asyncio.gather(*tasks, return_exceptions=True) - for result in results: - if isinstance(result, InvalidCredentialsError): - raise ConfigEntryAuthFailed("Invalid credentials") from result - - if isinstance(result, EndpointUnavailableError): - # In case the user attempts an action not allowed in their current plan, - # we merely log that message at INFO level (so the user is aware, - # but not spammed with ERROR messages that they cannot change): - LOGGER.debug(result) - - if isinstance(result, SimplipyError): - raise UpdateFailed(f"SimpliSafe error while updating: {result}") + try: + # Gather all system updates; exceptions will propagate + await asyncio.gather(*tasks) + except InvalidCredentialsError as err: + # Stop websocket immediately on auth failure + if self._websocket_task: + LOGGER.debug("Cancelling websocket loop due to invalid credentials") + await self._async_cancel_websocket_loop() + # Signal HA that credentials are invalid; user intervention is required + raise ConfigEntryAuthFailed("Invalid credentials") from err + except RequestError as err: + # Cloud-level request errors: wrap aiohttp errors + if self._websocket_task: + LOGGER.debug("Cancelling websocket loop due to request error") + await self._async_cancel_websocket_loop() + raise UpdateFailed( + f"Request error while updating all systems: {err}" + ) from err + except EndpointUnavailableError as err: + # Currently not raised by the API; included for future-proofing. + # Informational per-system (e.g., user plan restrictions) + LOGGER.debug("Endpoint unavailable: %s", err) + except SimplipyError as err: + # Any other SimplipyError not caught per-system + raise UpdateFailed(f"SimpliSafe error while updating: {err}") from err + else: + # Successful update, try to restart websocket if necessary + self._async_start_websocket_if_needed() diff --git a/homeassistant/components/simplisafe/lock.py b/homeassistant/components/simplisafe/lock.py index 9e29bb2051b..a0626898a21 100644 --- a/homeassistant/components/simplisafe/lock.py +++ b/homeassistant/components/simplisafe/lock.py @@ -108,7 +108,7 @@ class SimpliSafeLock(SimpliSafeEntity, LockEntity): """Update the entity when new data comes from the websocket.""" assert event.event_type - if state := STATE_MAP_FROM_WEBSOCKET_EVENT.get(event.event_type) is not None: + if (state := STATE_MAP_FROM_WEBSOCKET_EVENT.get(event.event_type)) is not None: self._attr_is_locked = state self.async_reset_error_count() else: diff --git a/tests/components/simplisafe/conftest.py b/tests/components/simplisafe/conftest.py index 12ed845c7d2..3b002cf07d5 100644 --- a/tests/components/simplisafe/conftest.py +++ b/tests/components/simplisafe/conftest.py @@ -1,6 +1,5 @@ """Define test fixtures for SimpliSafe.""" -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, Mock, patch import pytest @@ -87,7 +86,7 @@ def data_settings_fixture() -> JsonObjectType: def data_subscription_fixture() -> JsonObjectType: """Define subscription data.""" data = load_json_object_fixture("subscription_data.json", "simplisafe") - return {SYSTEM_ID: data} + return {SYSTEM_ID: data} # type: ignore[return-value] @pytest.fixture(name="reauth_config") @@ -98,11 +97,9 @@ def reauth_config_fixture() -> dict[str, str]: } -@pytest.fixture(name="setup_simplisafe") -async def setup_simplisafe_fixture( - hass: HomeAssistant, api: Mock, config: dict[str, str] -) -> AsyncGenerator[None]: - """Define a fixture to set up SimpliSafe.""" +@pytest.fixture(name="patch_simplisafe_api") +def patch_simplisafe_api_fixture(api: Mock, websocket: Mock): + """Patch the SimpliSafe API creation methods.""" with ( patch( "homeassistant.components.simplisafe.config_flow.API.async_from_auth", @@ -117,18 +114,22 @@ async def setup_simplisafe_fixture( return_value=api, ), patch( - "homeassistant.components.simplisafe.SimpliSafe._async_start_websocket_loop" - ), - patch( - "homeassistant.components.simplisafe.PLATFORMS", - [], + "homeassistant.components.simplisafe.SimpliSafe._async_start_websocket_if_needed", ), ): - assert await async_setup_component(hass, DOMAIN, config) - await hass.async_block_till_done() + api.websocket = websocket yield +@pytest.fixture(name="setup_simplisafe") +async def setup_simplisafe_fixture( + hass: HomeAssistant, api: Mock, config: dict[str, str], patch_simplisafe_api +) -> None: + """Define a fixture to set up SimpliSafe for config flow tests.""" + assert await async_setup_component(hass, DOMAIN, config) + await hass.async_block_till_done() + + @pytest.fixture(name="sms_config") def sms_config_fixture() -> dict[str, str]: """Define a SMS-based two-factor authentication config.""" @@ -150,6 +151,7 @@ def system_v3_fixture( system.sensor_data = data_sensor system.settings_data = data_settings system.generate_device_objects() + system.async_update = AsyncMock(return_value=None) return system diff --git a/tests/components/simplisafe/test_init.py b/tests/components/simplisafe/test_init.py index 130ce59cd4a..c449f8a5602 100644 --- a/tests/components/simplisafe/test_init.py +++ b/tests/components/simplisafe/test_init.py @@ -1,50 +1,134 @@ """Define tests for SimpliSafe setup.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock -from homeassistant.components.simplisafe import DOMAIN +from freezegun.api import FrozenDateTimeFactory +import pytest +from simplipy.errors import ( + EndpointUnavailableError, + InvalidCredentialsError, + RequestError, + SimplipyError, +) +from simplipy.websocket import WebsocketEvent + +from homeassistant.components.simplisafe import DEFAULT_SCAN_INTERVAL, DOMAIN +from homeassistant.config_entries import SOURCE_REAUTH +from homeassistant.const import STATE_UNAVAILABLE from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr from homeassistant.setup import async_setup_component +from tests.common import MockConfigEntry, async_fire_time_changed + async def test_base_station_migration( - hass: HomeAssistant, device_registry: dr.DeviceRegistry, api, config, config_entry + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + api: Mock, + config: dict[str, str], + config_entry: MockConfigEntry, + patch_simplisafe_api, ) -> None: - """Test that errors are shown when duplicates are added.""" - old_identifers = (DOMAIN, 12345) - new_identifiers = (DOMAIN, "12345") + """Test that old integer-based device identifiers are migrated to strings.""" + old_identifiers = {(DOMAIN, 12345)} + new_identifiers = {(DOMAIN, "12345")} device_registry.async_get_or_create( config_entry_id=config_entry.entry_id, - identifiers={old_identifers}, + identifiers=old_identifiers, manufacturer="SimpliSafe", name="old", ) - with ( - patch( - "homeassistant.components.simplisafe.config_flow.API.async_from_auth", - return_value=api, - ), - patch( - "homeassistant.components.simplisafe.API.async_from_auth", - return_value=api, - ), - patch( - "homeassistant.components.simplisafe.API.async_from_refresh_token", - return_value=api, - ), - patch( - "homeassistant.components.simplisafe.SimpliSafe._async_start_websocket_loop" - ), - patch( - "homeassistant.components.simplisafe.PLATFORMS", - [], - ), - ): - assert await async_setup_component(hass, DOMAIN, config) - await hass.async_block_till_done() + assert await async_setup_component(hass, DOMAIN, config) + await hass.async_block_till_done() - assert device_registry.async_get_device(identifiers={old_identifers}) is None - assert device_registry.async_get_device(identifiers={new_identifiers}) is not None + assert device_registry.async_get_device(identifiers=old_identifiers) is None + assert device_registry.async_get_device(identifiers=new_identifiers) is not None + + +async def test_coordinator_update_triggers_reauth_on_invalid_credentials( + hass: HomeAssistant, + config_entry: MockConfigEntry, + patch_simplisafe_api, + system_v3, + freezer: FrozenDateTimeFactory, +) -> None: + """Test that InvalidCredentialsError triggers a reauth flow.""" + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + system_v3.async_update = AsyncMock(side_effect=InvalidCredentialsError("fail")) + + freezer.tick(DEFAULT_SCAN_INTERVAL) + async_fire_time_changed(hass) + await hass.async_block_till_done(wait_background_tasks=True) + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + flow = flows[0] + assert flow.get("context", {}).get("source") == SOURCE_REAUTH + assert flow.get("context", {}).get("entry_id") == config_entry.entry_id + + +@pytest.mark.parametrize( + "exc", + [RequestError, EndpointUnavailableError, SimplipyError], +) +async def test_coordinator_update_failure_keeps_entity_available( + hass: HomeAssistant, + config_entry: MockConfigEntry, + patch_simplisafe_api, + system_v3, + freezer: FrozenDateTimeFactory, + exc: type[SimplipyError], +) -> None: + """Test that a single coordinator failure does not immediately mark entities unavailable.""" + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + assert hass.states.get("lock.front_door_lock").state != STATE_UNAVAILABLE + + system_v3.async_update = AsyncMock(side_effect=exc("fail")) + + # Trigger one coordinator failure: error_count goes from 0 to 1, below threshold. + freezer.tick(DEFAULT_SCAN_INTERVAL) + async_fire_time_changed(hass) + await hass.async_block_till_done(wait_background_tasks=True) + + assert hass.states.get("lock.front_door_lock").state != STATE_UNAVAILABLE + + +async def test_websocket_event_updates_entity_state( + hass: HomeAssistant, + config_entry: MockConfigEntry, + patch_simplisafe_api, + websocket: Mock, +) -> None: + """Test that a push update from the websocket changes entity state.""" + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + # Retrieve the event callback that was registered with the mock websocket. + assert websocket.add_event_callback.called + event_callback = websocket.add_event_callback.call_args[0][0] + + assert hass.states.get("lock.front_door_lock").state == "locked" + + # Fire an "unlock" websocket event for the test lock (system_id=12345, serial="987"). + # CID 9700 maps to EVENT_LOCK_UNLOCKED in the simplipy event mapping. + event_callback( + WebsocketEvent( + event_cid=9700, + info="Lock unlocked", + system_id=12345, + _raw_timestamp=0, + _video=None, + _vid=None, + sensor_serial="987", + ) + ) + await hass.async_block_till_done() + + assert hass.states.get("lock.front_door_lock").state == "unlocked"