From 011ebec00143a9e72a96a2dba6dde27824f8f4b3 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 13 May 2026 16:02:45 +0200 Subject: [PATCH] Adjust condition API (#170486) --- .../components/device_automation/condition.py | 2 +- homeassistant/helpers/condition.py | 45 +++++++--- tests/helpers/test_condition.py | 83 +++++++++++++++++-- 3 files changed, 108 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/device_automation/condition.py b/homeassistant/components/device_automation/condition.py index c4987b49ffc..314ae8722a5 100644 --- a/homeassistant/components/device_automation/condition.py +++ b/homeassistant/components/device_automation/condition.py @@ -85,7 +85,7 @@ class DeviceCondition(Condition): assert config.options is not None self._config = config.options - async def async_setup(self) -> None: + async def _async_setup(self) -> None: """Set up a device condition.""" platform = await async_get_device_automation_platform( self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index a20f0449b4d..027aeb21f13 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -22,6 +22,7 @@ from typing import ( TypedDict, Unpack, cast, + final, overload, override, ) @@ -292,10 +293,12 @@ _CONDITION_SCHEMA = _CONDITION_BASE_SCHEMA.extend( class ConditionChecker(abc.ABC): """Base class for condition checkers.""" + _set_up = False + _unloaded = False + def __init__(self, hass: HomeAssistant) -> None: """Initialize condition checker.""" self._hass = hass - self._unloaded = False def __call__( self, hass: HomeAssistant, variables: TemplateVarsType = None @@ -315,29 +318,45 @@ class ConditionChecker(abc.ABC): except Exception: _LOGGER.exception("Error while unloading condition checker") + @final async def async_setup(self) -> None: """Set up the condition checker. Users of conditions do not need to call this method directly. It is called automatically by async_from_config and async_conditions_from_config. + """ + await self._async_setup() + self._set_up = True + + async def _async_setup(self) -> None: + """Set up the condition checker. Intended to be overridden in derived classes that need to do setup. """ + @final def async_unload(self) -> None: """Clean up any resources held by the checker. Users of conditions must call this method when they are done with the checker to ensure resources are released. + """ + self._async_unload() + self._unloaded = True + + def _async_unload(self) -> None: + """Clean up any resources held by the checker. Intended to be overridden in derived classes that need to do unloading. """ - self._unloaded = True + @final def async_check( self, *, variables: TemplateVarsType = None, **kwargs: Never ) -> bool | None: """Check the condition.""" + if not self._set_up: + raise HomeAssistantError("Condition checker is not set up") with trace_condition(variables): result = self._async_check(variables=variables) condition_trace_update_result(result=result) @@ -375,11 +394,10 @@ class CompoundConditionChecker(ConditionChecker): super().__init__(hass) self._conditions = conditions - def async_unload(self) -> None: + def _async_unload(self) -> None: """Clean up child conditions.""" for condition in self._conditions: condition.async_unload() - super().async_unload() class Condition(ConditionChecker): @@ -523,9 +541,8 @@ class EntityConditionBase(Condition): self._valid_since.pop(entity_id, None) @override - async def async_setup(self) -> None: + async def _async_setup(self) -> None: """Set up state tracking for duration-based conditions.""" - await super().async_setup() if not self._duration or not self._needs_duration_tracking: return @@ -559,9 +576,8 @@ class EntityConditionBase(Condition): self._on_unload.append(unsub) @override - def async_unload(self) -> None: + def _async_unload(self) -> None: """Unsubscribe from listeners.""" - super().async_unload() for cb in self._on_unload: cb() self._on_unload.clear() @@ -1086,7 +1102,9 @@ async def async_from_config( f"Error rendering condition enabled template: {err}" ) from err if not enabled: - return DisabledConditionChecker(hass) + disabled_checker = DisabledConditionChecker(hass) + await disabled_checker.async_setup() + return disabled_checker condition_key: str = config[CONF_CONDITION] factory: Any = None @@ -1119,14 +1137,15 @@ async def async_from_config( while isinstance(check_factory, ft.partial): check_factory = check_factory.func + checker: ConditionChecker | ConditionCheckerType if inspect.iscoroutinefunction(check_factory): checker = await factory(hass, config) else: checker = factory(config) - if isinstance(checker, ConditionChecker): - await checker.async_setup() - return checker - return LegacyConditionChecker(hass, cast(ConditionCheckerType, checker)) + if not isinstance(checker, ConditionChecker): + checker = LegacyConditionChecker(hass, checker) + await checker.async_setup() + return checker async def async_and_from_config( diff --git a/tests/helpers/test_condition.py b/tests/helpers/test_condition.py index bb8b437ae37..3b2607f3f7f 100644 --- a/tests/helpers/test_condition.py +++ b/tests/helpers/test_condition.py @@ -4455,6 +4455,7 @@ async def test_condition_checker_call_calls_async_check( return True checker = MockChecker(hass) + await checker.async_setup() check_mock = Mock(wraps=checker.async_check) checker.async_check = check_mock @@ -5362,13 +5363,6 @@ async def test_async_from_config_calls_async_setup_on_checker( class StubChecker(condition.ConditionChecker): """Stub checker to track async_setup calls.""" - def __init__(self, hass: HomeAssistant) -> None: - super().__init__(hass) - self.setup_called = False - - async def async_setup(self) -> None: - self.setup_called = True - def _async_check(self, **kwargs: Any) -> bool: return True @@ -5390,4 +5384,77 @@ async def test_async_from_config_calls_async_setup_on_checker( result = await condition.async_from_config(hass, config) assert result is stub - assert stub.setup_called + assert stub._set_up is True + + +async def test_async_setup_invokes_async_setup_hook( + hass: HomeAssistant, +) -> None: + """Test that async_setup awaits _async_setup and sets _set_up.""" + + setup_hook = AsyncMock() + + class MockChecker(ConditionChecker): + async def _async_setup(self) -> None: + await setup_hook() + + def _async_check(self, **kwargs: Any) -> bool: + return True + + checker = MockChecker(hass) + + assert checker._set_up is False + setup_hook.assert_not_called() + + await checker.async_setup() + + setup_hook.assert_awaited_once() + assert checker._set_up is True + + +async def test_async_check_raises_before_setup( + hass: HomeAssistant, +) -> None: + """Test that async_check raises HomeAssistantError before async_setup is called.""" + + class MockChecker(ConditionChecker): + def _async_check(self, **kwargs: Any) -> bool: + return True + + checker = MockChecker(hass) + + with pytest.raises(HomeAssistantError, match="not set up"): + checker.async_check() + + with pytest.raises(HomeAssistantError, match="not set up"): + checker(hass) + + await checker.async_setup() + + assert checker.async_check() is True + assert checker(hass) is True + + +async def test_async_unload_invokes_async_unload_hook( + hass: HomeAssistant, +) -> None: + """Test that async_unload calls _async_unload and sets _unloaded.""" + + unload_hook = Mock() + + class MockChecker(ConditionChecker): + def _async_unload(self) -> None: + unload_hook() + + def _async_check(self, **kwargs: Any) -> bool: + return True + + checker = MockChecker(hass) + + assert checker._unloaded is False + unload_hook.assert_not_called() + + checker.async_unload() + + unload_hook.assert_called_once() + assert checker._unloaded is True