diff --git a/homeassistant/components/homewizard/config_flow.py b/homeassistant/components/homewizard/config_flow.py index dc9b6b61640..d46b827c88f 100644 --- a/homeassistant/components/homewizard/config_flow.py +++ b/homeassistant/components/homewizard/config_flow.py @@ -10,7 +10,7 @@ from homewizard_energy.errors import DisabledError, RequestError, UnsupportedErr from homewizard_energy.models import Device from voluptuous import Required, Schema -from homeassistant.components import zeroconf +from homeassistant.components import onboarding, zeroconf from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.const import CONF_IP_ADDRESS from homeassistant.data_entry_flow import AbortFlow, FlowResult @@ -113,7 +113,7 @@ class HomeWizardConfigFlow(ConfigFlow, domain=DOMAIN): ) -> FlowResult: """Confirm discovery.""" errors: dict[str, str] | None = None - if user_input is not None: + if user_input is not None or not onboarding.async_is_onboarded(self.hass): try: await self._async_try_connect(self.discovery.ip) except RecoverableError as ex: diff --git a/tests/components/homewizard/conftest.py b/tests/components/homewizard/conftest.py index c9a04c55dae..b1bfb1190dc 100644 --- a/tests/components/homewizard/conftest.py +++ b/tests/components/homewizard/conftest.py @@ -1,6 +1,7 @@ """Fixtures for HomeWizard integration tests.""" +from collections.abc import Generator import json -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from homewizard_energy.features import Features from homewizard_energy.models import Data, Device, State, System @@ -80,3 +81,13 @@ async def init_integration( await hass.async_block_till_done() return mock_config_entry + + +@pytest.fixture +def mock_onboarding() -> Generator[MagicMock, None, None]: + """Mock that Home Assistant is currently onboarding.""" + with patch( + "homeassistant.components.onboarding.async_is_onboarded", + return_value=False, + ) as mock_onboarding: + yield mock_onboarding diff --git a/tests/components/homewizard/test_config_flow.py b/tests/components/homewizard/test_config_flow.py index 106687f0b01..9b6648af3d3 100644 --- a/tests/components/homewizard/test_config_flow.py +++ b/tests/components/homewizard/test_config_flow.py @@ -1,6 +1,5 @@ """Test the homewizard config flow.""" -import logging -from unittest.mock import patch +from unittest.mock import MagicMock, patch from homewizard_energy.errors import DisabledError, RequestError, UnsupportedError @@ -13,8 +12,7 @@ from homeassistant.data_entry_flow import FlowResultType from .generator import get_mock_device from tests.common import MockConfigEntry - -_LOGGER = logging.getLogger(__name__) +from tests.test_util.aiohttp import AiohttpClientMocker async def test_manual_flow_works(hass, aioclient_mock): @@ -112,6 +110,114 @@ async def test_discovery_flow_works(hass, aioclient_mock): assert result["result"].unique_id == "HWE-P1_aabbccddeeff" +async def test_discovery_flow_during_onboarding( + hass, aioclient_mock: AiohttpClientMocker, mock_onboarding: MagicMock +) -> None: + """Test discovery setup flow during onboarding.""" + + with patch( + "homeassistant.components.homewizard.async_setup_entry", + return_value=True, + ) as mock_setup_entry, patch( + "homeassistant.components.homewizard.config_flow.HomeWizardEnergy", + return_value=get_mock_device(), + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=zeroconf.ZeroconfServiceInfo( + host="192.168.43.183", + addresses=["192.168.43.183"], + port=80, + hostname="p1meter-ddeeff.local.", + type="mock_type", + name="mock_name", + properties={ + "api_enabled": "1", + "path": "/api/v1", + "product_name": "P1 meter", + "product_type": "HWE-P1", + "serial": "aabbccddeeff", + }, + ), + ) + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "P1 meter (aabbccddeeff)" + assert result["data"][CONF_IP_ADDRESS] == "192.168.43.183" + + assert result["result"] + assert result["result"].unique_id == "HWE-P1_aabbccddeeff" + + assert len(mock_setup_entry.mock_calls) == 1 + assert len(mock_onboarding.mock_calls) == 1 + + +async def test_discovery_flow_during_onboarding_disabled_api( + hass, aioclient_mock: AiohttpClientMocker, mock_onboarding: MagicMock +) -> None: + """Test discovery setup flow during onboarding with a disabled API.""" + + def mock_initialize(): + raise DisabledError + + device = get_mock_device() + device.device.side_effect = mock_initialize + + with patch( + "homeassistant.components.homewizard.config_flow.HomeWizardEnergy", + return_value=device, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=zeroconf.ZeroconfServiceInfo( + host="192.168.43.183", + addresses=["192.168.43.183"], + port=80, + hostname="p1meter-ddeeff.local.", + type="mock_type", + name="mock_name", + properties={ + "api_enabled": "0", + "path": "/api/v1", + "product_name": "P1 meter", + "product_type": "HWE-P1", + "serial": "aabbccddeeff", + }, + ), + ) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "discovery_confirm" + assert result["errors"] == {"base": "api_not_enabled"} + + # We are onboarded, user enabled API again and picks up from discovery/config flow + device.device.side_effect = None + mock_onboarding.return_value = True + + with patch( + "homeassistant.components.homewizard.async_setup_entry", + return_value=True, + ) as mock_setup_entry, patch( + "homeassistant.components.homewizard.config_flow.HomeWizardEnergy", + return_value=device, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={"ip_address": "192.168.43.183"} + ) + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "P1 meter (aabbccddeeff)" + assert result["data"][CONF_IP_ADDRESS] == "192.168.43.183" + + assert result["result"] + assert result["result"].unique_id == "HWE-P1_aabbccddeeff" + + assert len(mock_setup_entry.mock_calls) == 1 + assert len(mock_onboarding.mock_calls) == 1 + + async def test_discovery_disabled_api(hass, aioclient_mock): """Test discovery detecting disabled api."""