diff --git a/homeassistant/components/tessie/config_flow.py b/homeassistant/components/tessie/config_flow.py index 14c6b93fdfd..fc350856b0f 100644 --- a/homeassistant/components/tessie/config_flow.py +++ b/homeassistant/components/tessie/config_flow.py @@ -3,15 +3,16 @@ from __future__ import annotations from collections.abc import Mapping -from http import HTTPStatus from typing import Any -from aiohttp import ClientConnectionError, ClientResponseError -from tessie_api import get_state_of_all_vehicles +from aiohttp import ClientConnectionError +from tesla_fleet_api.exceptions import InvalidToken, MissingToken, TeslaFleetError +from tesla_fleet_api.tessie import Tessie import voluptuous as vol from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.const import CONF_ACCESS_TOKEN +from homeassistant.core import HomeAssistant from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import DOMAIN @@ -23,6 +24,24 @@ DESCRIPTION_PLACEHOLDERS = { } +async def _async_validate_access_token( + hass: HomeAssistant, access_token: str, *, only_active: bool = False +) -> dict[str, str]: + """Validate a Tessie access token.""" + try: + await Tessie(async_get_clientsession(hass), access_token).list_vehicles( + only_active=only_active + ) + except InvalidToken, MissingToken: + return {CONF_ACCESS_TOKEN: "invalid_access_token"} + except ClientConnectionError: + return {"base": "cannot_connect"} + except TeslaFleetError: + return {"base": "unknown"} + + return {} + + class TessieConfigFlow(ConfigFlow, domain=DOMAIN): """Config Tessie API connection.""" @@ -35,20 +54,10 @@ class TessieConfigFlow(ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} if user_input: self._async_abort_entries_match(dict(user_input)) - try: - await get_state_of_all_vehicles( - session=async_get_clientsession(self.hass), - api_key=user_input[CONF_ACCESS_TOKEN], - only_active=True, - ) - except ClientResponseError as e: - if e.status == HTTPStatus.UNAUTHORIZED: - errors[CONF_ACCESS_TOKEN] = "invalid_access_token" - else: - errors["base"] = "unknown" - except ClientConnectionError: - errors["base"] = "cannot_connect" - else: + errors = await _async_validate_access_token( + self.hass, user_input[CONF_ACCESS_TOKEN], only_active=True + ) + if not errors: return self.async_create_entry( title="Tessie", data=user_input, @@ -74,19 +83,10 @@ class TessieConfigFlow(ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} if user_input: - try: - await get_state_of_all_vehicles( - session=async_get_clientsession(self.hass), - api_key=user_input[CONF_ACCESS_TOKEN], - ) - except ClientResponseError as e: - if e.status == HTTPStatus.UNAUTHORIZED: - errors[CONF_ACCESS_TOKEN] = "invalid_access_token" - else: - errors["base"] = "unknown" - except ClientConnectionError: - errors["base"] = "cannot_connect" - else: + errors = await _async_validate_access_token( + self.hass, user_input[CONF_ACCESS_TOKEN] + ) + if not errors: return self.async_update_reload_and_abort( self._get_reauth_entry(), data=user_input ) diff --git a/tests/components/tessie/test_config_flow.py b/tests/components/tessie/test_config_flow.py index d51d467002d..a958467374d 100644 --- a/tests/components/tessie/test_config_flow.py +++ b/tests/components/tessie/test_config_flow.py @@ -1,8 +1,10 @@ """Test the Tessie config flow.""" -from unittest.mock import patch +from collections.abc import Iterator +from unittest.mock import AsyncMock, patch import pytest +from tesla_fleet_api.exceptions import InvalidToken, MissingToken, TeslaFleetError from homeassistant import config_entries from homeassistant.components.tessie.const import DOMAIN @@ -10,29 +12,23 @@ from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType -from .common import ( - ERROR_AUTH, - ERROR_CONNECTION, - ERROR_UNKNOWN, - TEST_CONFIG, - TEST_STATE_OF_ALL_VEHICLES, -) +from .common import ERROR_CONNECTION, TEST_CONFIG, TEST_STATE_OF_ALL_VEHICLES from tests.common import MockConfigEntry @pytest.fixture(autouse=True) -def mock_config_flow_get_state_of_all_vehicles(): - """Mock get_state_of_all_vehicles in config flow.""" +def mock_config_flow_list_vehicles() -> Iterator[AsyncMock]: + """Mock Tessie.list_vehicles in config flow.""" with patch( - "homeassistant.components.tessie.config_flow.get_state_of_all_vehicles", + "homeassistant.components.tessie.config_flow.Tessie.list_vehicles", return_value=TEST_STATE_OF_ALL_VEHICLES, - ) as mock_config_flow_get_state_of_all_vehicles: - yield mock_config_flow_get_state_of_all_vehicles + ) as mock_list_vehicles: + yield mock_list_vehicles @pytest.fixture(autouse=True) -def mock_async_setup_entry(): +def mock_async_setup_entry() -> Iterator[AsyncMock]: """Mock async_setup_entry.""" with patch( "homeassistant.components.tessie.async_setup_entry", @@ -43,8 +39,8 @@ def mock_async_setup_entry(): async def test_form( hass: HomeAssistant, - mock_config_flow_get_state_of_all_vehicles, - mock_async_setup_entry, + mock_config_flow_list_vehicles: AsyncMock, + mock_async_setup_entry: AsyncMock, ) -> None: """Test we get the form.""" @@ -60,7 +56,7 @@ async def test_form( ) await hass.async_block_till_done() assert len(mock_async_setup_entry.mock_calls) == 1 - assert len(mock_config_flow_get_state_of_all_vehicles.mock_calls) == 1 + assert len(mock_config_flow_list_vehicles.mock_calls) == 1 assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["title"] == "Tessie" @@ -69,8 +65,6 @@ async def test_form( async def test_abort( hass: HomeAssistant, - mock_config_flow_get_state_of_all_vehicles, - mock_async_setup_entry, ) -> None: """Test a duplicate entry aborts.""" @@ -97,13 +91,17 @@ async def test_abort( @pytest.mark.parametrize( ("side_effect", "error"), [ - (ERROR_AUTH, {CONF_ACCESS_TOKEN: "invalid_access_token"}), - (ERROR_UNKNOWN, {"base": "unknown"}), + (InvalidToken(), {CONF_ACCESS_TOKEN: "invalid_access_token"}), + (MissingToken(), {CONF_ACCESS_TOKEN: "invalid_access_token"}), + (TeslaFleetError(), {"base": "unknown"}), (ERROR_CONNECTION, {"base": "cannot_connect"}), ], ) async def test_form_errors( - hass: HomeAssistant, side_effect, error, mock_config_flow_get_state_of_all_vehicles + hass: HomeAssistant, + side_effect: BaseException, + error: dict[str, str], + mock_config_flow_list_vehicles: AsyncMock, ) -> None: """Test errors are handled.""" @@ -111,7 +109,7 @@ async def test_form_errors( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - mock_config_flow_get_state_of_all_vehicles.side_effect = side_effect + mock_config_flow_list_vehicles.side_effect = side_effect result2 = await hass.config_entries.flow.async_configure( result1["flow_id"], TEST_CONFIG, @@ -121,7 +119,7 @@ async def test_form_errors( assert result2["errors"] == error # Complete the flow - mock_config_flow_get_state_of_all_vehicles.side_effect = None + mock_config_flow_list_vehicles.side_effect = None result3 = await hass.config_entries.flow.async_configure( result2["flow_id"], TEST_CONFIG, @@ -132,8 +130,8 @@ async def test_form_errors( async def test_reauth( hass: HomeAssistant, - mock_config_flow_get_state_of_all_vehicles, - mock_async_setup_entry, + mock_config_flow_list_vehicles: AsyncMock, + mock_async_setup_entry: AsyncMock, ) -> None: """Test reauth flow.""" @@ -155,7 +153,7 @@ async def test_reauth( ) await hass.async_block_till_done() assert len(mock_async_setup_entry.mock_calls) == 1 - assert len(mock_config_flow_get_state_of_all_vehicles.mock_calls) == 1 + assert len(mock_config_flow_list_vehicles.mock_calls) == 1 assert result2["type"] is FlowResultType.ABORT assert result2["reason"] == "reauth_successful" @@ -165,21 +163,22 @@ async def test_reauth( @pytest.mark.parametrize( ("side_effect", "error"), [ - (ERROR_AUTH, {CONF_ACCESS_TOKEN: "invalid_access_token"}), - (ERROR_UNKNOWN, {"base": "unknown"}), + (InvalidToken(), {CONF_ACCESS_TOKEN: "invalid_access_token"}), + (MissingToken(), {CONF_ACCESS_TOKEN: "invalid_access_token"}), + (TeslaFleetError(), {"base": "unknown"}), (ERROR_CONNECTION, {"base": "cannot_connect"}), ], ) async def test_reauth_errors( hass: HomeAssistant, - mock_config_flow_get_state_of_all_vehicles, - mock_async_setup_entry, - side_effect, - error, + mock_config_flow_list_vehicles: AsyncMock, + mock_async_setup_entry: AsyncMock, + side_effect: BaseException, + error: dict[str, str], ) -> None: """Test reauth flows that fail.""" - mock_config_flow_get_state_of_all_vehicles.side_effect = side_effect + mock_config_flow_list_vehicles.side_effect = side_effect mock_entry = MockConfigEntry( domain=DOMAIN, @@ -199,7 +198,7 @@ async def test_reauth_errors( assert result2["errors"] == error # Complete the flow - mock_config_flow_get_state_of_all_vehicles.side_effect = None + mock_config_flow_list_vehicles.side_effect = None result3 = await hass.config_entries.flow.async_configure( result2["flow_id"], TEST_CONFIG,