1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-23 20:39:01 +00:00

Modernise condition checker in helper (#159159)

This commit is contained in:
Artur Pragacz
2025-12-16 10:46:10 +01:00
committed by GitHub
parent dbfdaf6a2e
commit 2d33a720f7
7 changed files with 99 additions and 54 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol from typing import Any, Protocol
import voluptuous as vol import voluptuous as vol
@@ -11,18 +11,15 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition, Condition,
ConditionChecker,
ConditionCheckerType, ConditionCheckerType,
ConditionConfig, ConditionConfig,
trace_condition_function,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from . import DeviceAutomationType, async_get_device_automation_platform from . import DeviceAutomationType, async_get_device_automation_platform
from .helpers import async_validate_device_automation_config from .helpers import async_validate_device_automation_config
if TYPE_CHECKING:
from homeassistant.helpers import condition
class DeviceAutomationConditionProtocol(Protocol): class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules. """Define the format of device_condition modules.
@@ -90,15 +87,21 @@ class DeviceCondition(Condition):
assert config.options is not None assert config.options is not None
self._config = config.options self._config = config.options
async def async_get_checker(self) -> condition.ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Test a device condition.""" """Test a device condition."""
platform = await async_get_device_automation_platform( platform = await async_get_device_automation_platform(
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
) )
return trace_condition_function( platform_checker = platform.async_condition_from_config(
platform.async_condition_from_config(self._hass, self._config) self._hass, self._config
) )
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
result = platform_checker(self._hass, variables)
return result is not False
return checker
CONDITIONS: dict[str, type[Condition]] = { CONDITIONS: dict[str, type[Condition]] = {
"_device": DeviceCondition, "_device": DeviceCondition,

View File

@@ -1,7 +1,7 @@
"""Provides conditions for lights.""" """Provides conditions for lights."""
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Final, override from typing import TYPE_CHECKING, Any, Final, Unpack, override
import voluptuous as vol import voluptuous as vol
@@ -10,11 +10,11 @@ from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers import config_validation as cv, target from homeassistant.helpers import config_validation as cv, target
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition, Condition,
ConditionCheckerType, ConditionChecker,
ConditionCheckParams,
ConditionConfig, ConditionConfig,
trace_condition_function,
) )
from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN from .const import DOMAIN
@@ -61,7 +61,7 @@ class StateConditionBase(Condition):
self._state = state self._state = state
@override @override
async def async_get_checker(self) -> ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker.""" """Get the condition checker."""
def check_any_match_state(states: list[str]) -> bool: def check_any_match_state(states: list[str]) -> bool:
@@ -78,12 +78,11 @@ class StateConditionBase(Condition):
elif self._behavior == BEHAVIOR_ALL: elif self._behavior == BEHAVIOR_ALL:
matcher = check_all_match_state matcher = check_all_match_state
@trace_condition_function def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test state condition.""" """Test state condition."""
target_selection = target.TargetSelection(self._target) target_selection = target.TargetSelection(self._target)
targeted_entities = target.async_extract_referenced_entity_ids( targeted_entities = target.async_extract_referenced_entity_ids(
hass, target_selection, expand_group=False self._hass, target_selection, expand_group=False
) )
referenced_entity_ids = targeted_entities.referenced.union( referenced_entity_ids = targeted_entities.referenced.union(
targeted_entities.indirectly_referenced targeted_entities.indirectly_referenced
@@ -96,7 +95,7 @@ class StateConditionBase(Condition):
light_entity_states = [ light_entity_states = [
state.state state.state
for entity_id in light_entity_ids for entity_id in light_entity_ids
if (state := hass.states.get(entity_id)) if (state := self._hass.states.get(entity_id))
and state.state in STATE_CONDITION_VALID_STATES and state.state in STATE_CONDITION_VALID_STATES
] ]
return matcher(light_entity_states) return matcher(light_entity_states)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, cast from typing import Any, Unpack, cast
import voluptuous as vol import voluptuous as vol
@@ -13,14 +13,14 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition, Condition,
ConditionCheckerType, ConditionChecker,
ConditionCheckParams,
ConditionConfig, ConditionConfig,
condition_trace_set_result, condition_trace_set_result,
condition_trace_update_result, condition_trace_update_result,
trace_condition_function,
) )
from homeassistant.helpers.sun import get_astral_event_date from homeassistant.helpers.sun import get_astral_event_date
from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = { _OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
@@ -154,17 +154,16 @@ class SunCondition(Condition):
assert config.options is not None assert config.options is not None
self._options = config.options self._options = config.options
async def async_get_checker(self) -> ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with sun based condition.""" """Wrap action method with sun based condition."""
before = self._options.get("before") before = self._options.get("before")
after = self._options.get("after") after = self._options.get("after")
before_offset = self._options.get("before_offset") before_offset = self._options.get("before_offset")
after_offset = self._options.get("after_offset") after_offset = self._options.get("after_offset")
@trace_condition_function def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition.""" """Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset) return sun(self._hass, before, after, before_offset, after_offset)
return sun_if return sun_if

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, cast from typing import Any, Unpack, cast
import voluptuous as vol import voluptuous as vol
@@ -22,11 +22,11 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition, Condition,
ConditionCheckerType, ConditionChecker,
ConditionCheckParams,
ConditionConfig, ConditionConfig,
trace_condition_function,
) )
from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.typing import ConfigType
from . import in_zone from . import in_zone
@@ -118,13 +118,12 @@ class ZoneCondition(Condition):
assert config.options is not None assert config.options is not None
self._options = config.options self._options = config.options
async def async_get_checker(self) -> ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with zone based condition.""" """Wrap action method with zone based condition."""
entity_ids = self._options.get(CONF_ENTITY_ID, []) entity_ids = self._options.get(CONF_ENTITY_ID, [])
zone_entity_ids = self._options.get(CONF_ZONE, []) zone_entity_ids = self._options.get(CONF_ZONE, [])
@trace_condition_function def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition.""" """Test if condition."""
errors = [] errors = []
@@ -133,7 +132,7 @@ class ZoneCondition(Condition):
entity_ok = False entity_ok = False
for zone_entity_id in zone_entity_ids: for zone_entity_id in zone_entity_ids:
try: try:
if zone(hass, zone_entity_id, entity_id): if zone(self._hass, zone_entity_id, entity_id):
entity_ok = True entity_ok = True
except ConditionErrorMessage as ex: except ConditionErrorMessage as ex:
errors.append( errors.append(

View File

@@ -13,7 +13,7 @@ import inspect
import logging import logging
import re import re
import sys import sys
from typing import TYPE_CHECKING, Any, Protocol, cast from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload
import voluptuous as vol import voluptuous as vol
@@ -298,7 +298,7 @@ class Condition(abc.ABC):
self._hass = hass self._hass = hass
@abc.abstractmethod @abc.abstractmethod
async def async_get_checker(self) -> ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker.""" """Get the condition checker."""
@@ -319,7 +319,23 @@ class ConditionConfig:
target: dict[str, Any] | None = None target: dict[str, Any] | None = None
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None] class ConditionCheckParams(TypedDict, total=False):
"""Condition check params."""
variables: TemplateVarsType
class ConditionChecker(Protocol):
"""Protocol for condition checker callable with typed kwargs."""
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Check the condition."""
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
type ConditionCheckerTypeOptional = Callable[
[HomeAssistant, TemplateVarsType], bool | None
]
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement: def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
@@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
trace_stack_pop(trace_stack_cv) trace_stack_pop(trace_stack_cv)
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType: @overload
def trace_condition_function(
condition: ConditionCheckerType,
) -> ConditionCheckerType: ...
@overload
def trace_condition_function(
condition: ConditionCheckerTypeOptional,
) -> ConditionCheckerTypeOptional: ...
def trace_condition_function(
condition: ConditionCheckerType | ConditionCheckerTypeOptional,
) -> ConditionCheckerType | ConditionCheckerTypeOptional:
"""Wrap a condition function to enable basic tracing.""" """Wrap a condition function to enable basic tracing."""
@ft.wraps(condition) @ft.wraps(condition)
@@ -420,10 +450,20 @@ async def _async_get_condition_platform(
) from None ) from None
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
new_checker = await condition.async_get_checker()
@trace_condition_function
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
return new_checker(variables=variables)
return checker
async def async_from_config( async def async_from_config(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
) -> ConditionCheckerType: ) -> ConditionCheckerTypeOptional:
"""Turn a condition configuration into a method. """Turn a condition configuration into a method.
Should be run on the event loop. Should be run on the event loop.
@@ -466,7 +506,7 @@ async def async_from_config(
target=config.get(CONF_TARGET), target=config.get(CONF_TARGET),
), ),
) )
return await condition.async_get_checker() return await _async_get_checker(condition)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None) factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
@@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
name: str, name: str,
) -> Callable[[TemplateVarsType], bool]: ) -> Callable[[TemplateVarsType], bool]:
"""AND all conditions.""" """AND all conditions."""
checks: list[ConditionCheckerType] = [ checks = [
await async_from_config(hass, condition_config) await async_from_config(hass, condition_config)
for condition_config in condition_configs for condition_config in condition_configs
] ]
@@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
continue continue
description = {"fields": yaml_description.get("fields", {})} description = {"fields": yaml_description.get("fields", {})}
if (target := yaml_description.get("target")) is not None: if (target := yaml_description.get("target")) is not None:
description["target"] = target description["target"] = target

View File

@@ -86,7 +86,7 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.signal_type import SignalType, SignalTypeFormat from homeassistant.util.signal_type import SignalType, SignalTypeFormat
from . import condition, config_validation as cv, service, template from . import condition, config_validation as cv, service, template
from .condition import ConditionCheckerType, trace_condition_function from .condition import ConditionCheckerTypeOptional, trace_condition_function
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
from .event import async_call_later, async_track_template from .event import async_call_later, async_track_template
from .script_variables import ScriptRunVariables, ScriptVariables from .script_variables import ScriptRunVariables, ScriptVariables
@@ -675,12 +675,14 @@ class _ScriptRun:
### Condition actions ### ### Condition actions ###
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType: async def _async_get_condition(
self, config: ConfigType
) -> ConditionCheckerTypeOptional:
return await self._script._async_get_condition(config) # noqa: SLF001 return await self._script._async_get_condition(config) # noqa: SLF001
def _test_conditions( def _test_conditions(
self, self,
conditions: list[ConditionCheckerType], conditions: list[ConditionCheckerTypeOptional],
name: str, name: str,
condition_path: str | None = None, condition_path: str | None = None,
) -> bool | None: ) -> bool | None:
@@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
class _ChooseData(TypedDict): class _ChooseData(TypedDict):
choices: list[tuple[list[ConditionCheckerType], Script]] choices: list[tuple[list[ConditionCheckerTypeOptional], Script]]
default: Script | None default: Script | None
class _IfData(TypedDict): class _IfData(TypedDict):
if_conditions: list[ConditionCheckerType] if_conditions: list[ConditionCheckerTypeOptional]
if_then: Script if_then: Script
if_else: Script | None if_else: Script | None
@@ -1486,7 +1488,9 @@ class Script:
self._max_exceeded = max_exceeded self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED: if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock() self._queue_lck = asyncio.Lock()
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {} self._config_cache: dict[
frozenset[tuple[str, str]], ConditionCheckerTypeOptional
] = {}
self._repeat_script: dict[int, Script] = {} self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {} self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {} self._if_data: dict[int, _IfData] = {}
@@ -1857,7 +1861,9 @@ class Script:
return return
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state))) await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType: async def _async_get_condition(
self, config: ConfigType
) -> ConditionCheckerTypeOptional:
config_cache_key = frozenset((k, str(v)) for k, v in config.items()) config_cache_key = frozenset((k, str(v)) for k, v in config.items())
if not (cond := self._config_cache.get(config_cache_key)): if not (cond := self._config_cache.get(config_cache_key)):
cond = await condition.async_from_config(self._hass, config) cond = await condition.async_from_config(self._hass, config)

View File

@@ -36,7 +36,7 @@ from homeassistant.helpers import (
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition, Condition,
ConditionCheckerType, ConditionChecker,
async_validate_condition_config, async_validate_condition_config,
) )
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
@@ -2126,16 +2126,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
class MockCondition1(MockCondition): class MockCondition1(MockCondition):
"""Mock condition 1.""" """Mock condition 1."""
async def async_get_checker(self) -> ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Evaluate state based on configuration.""" """Evaluate state based on configuration."""
return lambda hass, vars: True return lambda **kwargs: True
class MockCondition2(MockCondition): class MockCondition2(MockCondition):
"""Mock condition 2.""" """Mock condition 2."""
async def async_get_checker(self) -> ConditionCheckerType: async def async_get_checker(self) -> ConditionChecker:
"""Evaluate state based on configuration.""" """Evaluate state based on configuration."""
return lambda hass, vars: False return lambda **kwargs: False
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]: async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
return { return {