mirror of
https://github.com/home-assistant/core.git
synced 2026-05-21 16:00:12 +01:00
Adjust condition API (#170486)
This commit is contained in:
@@ -85,7 +85,7 @@ class DeviceCondition(Condition):
|
||||
assert config.options is not None
|
||||
self._config = config.options
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
async def _async_setup(self) -> None:
|
||||
"""Set up a device condition."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import (
|
||||
TypedDict,
|
||||
Unpack,
|
||||
cast,
|
||||
final,
|
||||
overload,
|
||||
override,
|
||||
)
|
||||
@@ -292,10 +293,12 @@ _CONDITION_SCHEMA = _CONDITION_BASE_SCHEMA.extend(
|
||||
class ConditionChecker(abc.ABC):
|
||||
"""Base class for condition checkers."""
|
||||
|
||||
_set_up = False
|
||||
_unloaded = False
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize condition checker."""
|
||||
self._hass = hass
|
||||
self._unloaded = False
|
||||
|
||||
def __call__(
|
||||
self, hass: HomeAssistant, variables: TemplateVarsType = None
|
||||
@@ -315,29 +318,45 @@ class ConditionChecker(abc.ABC):
|
||||
except Exception:
|
||||
_LOGGER.exception("Error while unloading condition checker")
|
||||
|
||||
@final
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up the condition checker.
|
||||
|
||||
Users of conditions do not need to call this method directly. It is called
|
||||
automatically by async_from_config and async_conditions_from_config.
|
||||
"""
|
||||
await self._async_setup()
|
||||
self._set_up = True
|
||||
|
||||
async def _async_setup(self) -> None:
|
||||
"""Set up the condition checker.
|
||||
|
||||
Intended to be overridden in derived classes that need to do setup.
|
||||
"""
|
||||
|
||||
@final
|
||||
def async_unload(self) -> None:
|
||||
"""Clean up any resources held by the checker.
|
||||
|
||||
Users of conditions must call this method when they are done with the
|
||||
checker to ensure resources are released.
|
||||
"""
|
||||
self._async_unload()
|
||||
self._unloaded = True
|
||||
|
||||
def _async_unload(self) -> None:
|
||||
"""Clean up any resources held by the checker.
|
||||
|
||||
Intended to be overridden in derived classes that need to do unloading.
|
||||
"""
|
||||
self._unloaded = True
|
||||
|
||||
@final
|
||||
def async_check(
|
||||
self, *, variables: TemplateVarsType = None, **kwargs: Never
|
||||
) -> bool | None:
|
||||
"""Check the condition."""
|
||||
if not self._set_up:
|
||||
raise HomeAssistantError("Condition checker is not set up")
|
||||
with trace_condition(variables):
|
||||
result = self._async_check(variables=variables)
|
||||
condition_trace_update_result(result=result)
|
||||
@@ -375,11 +394,10 @@ class CompoundConditionChecker(ConditionChecker):
|
||||
super().__init__(hass)
|
||||
self._conditions = conditions
|
||||
|
||||
def async_unload(self) -> None:
|
||||
def _async_unload(self) -> None:
|
||||
"""Clean up child conditions."""
|
||||
for condition in self._conditions:
|
||||
condition.async_unload()
|
||||
super().async_unload()
|
||||
|
||||
|
||||
class Condition(ConditionChecker):
|
||||
@@ -523,9 +541,8 @@ class EntityConditionBase(Condition):
|
||||
self._valid_since.pop(entity_id, None)
|
||||
|
||||
@override
|
||||
async def async_setup(self) -> None:
|
||||
async def _async_setup(self) -> None:
|
||||
"""Set up state tracking for duration-based conditions."""
|
||||
await super().async_setup()
|
||||
if not self._duration or not self._needs_duration_tracking:
|
||||
return
|
||||
|
||||
@@ -559,9 +576,8 @@ class EntityConditionBase(Condition):
|
||||
self._on_unload.append(unsub)
|
||||
|
||||
@override
|
||||
def async_unload(self) -> None:
|
||||
def _async_unload(self) -> None:
|
||||
"""Unsubscribe from listeners."""
|
||||
super().async_unload()
|
||||
for cb in self._on_unload:
|
||||
cb()
|
||||
self._on_unload.clear()
|
||||
@@ -1086,7 +1102,9 @@ async def async_from_config(
|
||||
f"Error rendering condition enabled template: {err}"
|
||||
) from err
|
||||
if not enabled:
|
||||
return DisabledConditionChecker(hass)
|
||||
disabled_checker = DisabledConditionChecker(hass)
|
||||
await disabled_checker.async_setup()
|
||||
return disabled_checker
|
||||
|
||||
condition_key: str = config[CONF_CONDITION]
|
||||
factory: Any = None
|
||||
@@ -1119,14 +1137,15 @@ async def async_from_config(
|
||||
while isinstance(check_factory, ft.partial):
|
||||
check_factory = check_factory.func
|
||||
|
||||
checker: ConditionChecker | ConditionCheckerType
|
||||
if inspect.iscoroutinefunction(check_factory):
|
||||
checker = await factory(hass, config)
|
||||
else:
|
||||
checker = factory(config)
|
||||
if isinstance(checker, ConditionChecker):
|
||||
await checker.async_setup()
|
||||
return checker
|
||||
return LegacyConditionChecker(hass, cast(ConditionCheckerType, checker))
|
||||
if not isinstance(checker, ConditionChecker):
|
||||
checker = LegacyConditionChecker(hass, checker)
|
||||
await checker.async_setup()
|
||||
return checker
|
||||
|
||||
|
||||
async def async_and_from_config(
|
||||
|
||||
@@ -4455,6 +4455,7 @@ async def test_condition_checker_call_calls_async_check(
|
||||
return True
|
||||
|
||||
checker = MockChecker(hass)
|
||||
await checker.async_setup()
|
||||
check_mock = Mock(wraps=checker.async_check)
|
||||
checker.async_check = check_mock
|
||||
|
||||
@@ -5362,13 +5363,6 @@ async def test_async_from_config_calls_async_setup_on_checker(
|
||||
class StubChecker(condition.ConditionChecker):
|
||||
"""Stub checker to track async_setup calls."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
super().__init__(hass)
|
||||
self.setup_called = False
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
self.setup_called = True
|
||||
|
||||
def _async_check(self, **kwargs: Any) -> bool:
|
||||
return True
|
||||
|
||||
@@ -5390,4 +5384,77 @@ async def test_async_from_config_calls_async_setup_on_checker(
|
||||
result = await condition.async_from_config(hass, config)
|
||||
|
||||
assert result is stub
|
||||
assert stub.setup_called
|
||||
assert stub._set_up is True
|
||||
|
||||
|
||||
async def test_async_setup_invokes_async_setup_hook(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that async_setup awaits _async_setup and sets _set_up."""
|
||||
|
||||
setup_hook = AsyncMock()
|
||||
|
||||
class MockChecker(ConditionChecker):
|
||||
async def _async_setup(self) -> None:
|
||||
await setup_hook()
|
||||
|
||||
def _async_check(self, **kwargs: Any) -> bool:
|
||||
return True
|
||||
|
||||
checker = MockChecker(hass)
|
||||
|
||||
assert checker._set_up is False
|
||||
setup_hook.assert_not_called()
|
||||
|
||||
await checker.async_setup()
|
||||
|
||||
setup_hook.assert_awaited_once()
|
||||
assert checker._set_up is True
|
||||
|
||||
|
||||
async def test_async_check_raises_before_setup(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that async_check raises HomeAssistantError before async_setup is called."""
|
||||
|
||||
class MockChecker(ConditionChecker):
|
||||
def _async_check(self, **kwargs: Any) -> bool:
|
||||
return True
|
||||
|
||||
checker = MockChecker(hass)
|
||||
|
||||
with pytest.raises(HomeAssistantError, match="not set up"):
|
||||
checker.async_check()
|
||||
|
||||
with pytest.raises(HomeAssistantError, match="not set up"):
|
||||
checker(hass)
|
||||
|
||||
await checker.async_setup()
|
||||
|
||||
assert checker.async_check() is True
|
||||
assert checker(hass) is True
|
||||
|
||||
|
||||
async def test_async_unload_invokes_async_unload_hook(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that async_unload calls _async_unload and sets _unloaded."""
|
||||
|
||||
unload_hook = Mock()
|
||||
|
||||
class MockChecker(ConditionChecker):
|
||||
def _async_unload(self) -> None:
|
||||
unload_hook()
|
||||
|
||||
def _async_check(self, **kwargs: Any) -> bool:
|
||||
return True
|
||||
|
||||
checker = MockChecker(hass)
|
||||
|
||||
assert checker._unloaded is False
|
||||
unload_hook.assert_not_called()
|
||||
|
||||
checker.async_unload()
|
||||
|
||||
unload_hook.assert_called_once()
|
||||
assert checker._unloaded is True
|
||||
|
||||
Reference in New Issue
Block a user