diff --git a/homeassistant/components/google_sheets/__init__.py b/homeassistant/components/google_sheets/__init__.py index ff0ce62ec24..99981348151 100644 --- a/homeassistant/components/google_sheets/__init__.py +++ b/homeassistant/components/google_sheets/__init__.py @@ -7,7 +7,12 @@ import aiohttp from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_TOKEN from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady +from homeassistant.exceptions import ( + ConfigEntryAuthFailed, + ConfigEntryNotReady, + OAuth2TokenRequestError, + OAuth2TokenRequestReauthError, +) from homeassistant.helpers import config_validation as cv from homeassistant.helpers.config_entry_oauth2_flow import ( OAuth2Session, @@ -39,11 +44,11 @@ async def async_setup_entry( session = OAuth2Session(hass, entry, implementation) try: await session.async_ensure_token_valid() - except aiohttp.ClientResponseError as err: - if 400 <= err.status < 500: - raise ConfigEntryAuthFailed( - "OAuth session is not valid, reauth required" - ) from err + except OAuth2TokenRequestReauthError as err: + raise ConfigEntryAuthFailed( + "OAuth session is not valid, reauth required" + ) from err + except OAuth2TokenRequestError as err: raise ConfigEntryNotReady from err except aiohttp.ClientError as err: raise ConfigEntryNotReady from err diff --git a/tests/components/google_sheets/test_init.py b/tests/components/google_sheets/test_init.py index e3fa4842f19..7bb7369c7b5 100644 --- a/tests/components/google_sheets/test_init.py +++ b/tests/components/google_sheets/test_init.py @@ -4,7 +4,7 @@ from collections.abc import Awaitable, Callable, Coroutine import http import time from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch from freezegun import freeze_time from gspread.exceptions import APIError @@ -29,7 +29,12 @@ from homeassistant.components.google_sheets.services import ( ) from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError, ServiceValidationError +from homeassistant.exceptions import ( + HomeAssistantError, + OAuth2TokenRequestReauthError, + OAuth2TokenRequestTransientError, + ServiceValidationError, +) from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -199,6 +204,64 @@ async def test_expired_token_refresh_failure( assert entries[0].state is expected_state +async def test_setup_oauth_reauth_error( + hass: HomeAssistant, config_entry: MockConfigEntry +) -> None: + """Test a token refresh reauth error puts the config entry in setup error state.""" + config_entry.add_to_hass(hass) + + assert await async_setup_component(hass, APPLICATION_CREDENTIALS_DOMAIN, {}) + await async_import_client_credential( + hass, + DOMAIN, + ClientCredential("client-id", "client-secret"), + DOMAIN, + ) + + with ( + patch.object(config_entry, "async_start_reauth") as mock_async_start_reauth, + patch( + "homeassistant.components.google_sheets.OAuth2Session.async_ensure_token_valid", + side_effect=OAuth2TokenRequestReauthError( + domain=DOMAIN, request_info=Mock() + ), + ), + ): + await hass.config_entries.async_setup(config_entry.entry_id) + + await hass.async_block_till_done() + + assert config_entry.state is ConfigEntryState.SETUP_ERROR + mock_async_start_reauth.assert_called_once_with(hass) + + +async def test_setup_oauth_transient_error( + hass: HomeAssistant, config_entry: MockConfigEntry +) -> None: + """Test a token refresh transient error sets the config entry to retry setup.""" + config_entry.add_to_hass(hass) + + assert await async_setup_component(hass, APPLICATION_CREDENTIALS_DOMAIN, {}) + await async_import_client_credential( + hass, + DOMAIN, + ClientCredential("client-id", "client-secret"), + DOMAIN, + ) + + with patch( + "homeassistant.components.google_sheets.OAuth2Session.async_ensure_token_valid", + side_effect=OAuth2TokenRequestTransientError( + domain=DOMAIN, request_info=Mock() + ), + ): + await hass.config_entries.async_setup(config_entry.entry_id) + + await hass.async_block_till_done() + + assert config_entry.state is ConfigEntryState.SETUP_RETRY + + @pytest.mark.parametrize( ("add_created_column_param", "expected_row"), [