1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-20 02:48:57 +00:00

Modernise condition checker in helper (#159159)

This commit is contained in:
Artur Pragacz
2025-12-16 10:46:10 +01:00
committed by GitHub
parent dbfdaf6a2e
commit 2d33a720f7
7 changed files with 99 additions and 54 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
from typing import Any, Protocol
import voluptuous as vol
@@ -11,18 +11,15 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import (
Condition,
ConditionChecker,
ConditionCheckerType,
ConditionConfig,
trace_condition_function,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from . import DeviceAutomationType, async_get_device_automation_platform
from .helpers import async_validate_device_automation_config
if TYPE_CHECKING:
from homeassistant.helpers import condition
class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.
@@ -90,15 +87,21 @@ class DeviceCondition(Condition):
assert config.options is not None
self._config = config.options
async def async_get_checker(self) -> condition.ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(
platform.async_condition_from_config(self._hass, self._config)
platform_checker = platform.async_condition_from_config(
self._hass, self._config
)
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
result = platform_checker(self._hass, variables)
return result is not False
return checker
CONDITIONS: dict[str, type[Condition]] = {
"_device": DeviceCondition,

View File

@@ -1,7 +1,7 @@
"""Provides conditions for lights."""
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Final, override
from typing import TYPE_CHECKING, Any, Final, Unpack, override
import voluptuous as vol
@@ -10,11 +10,11 @@ from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers import config_validation as cv, target
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
trace_condition_function,
)
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
@@ -61,7 +61,7 @@ class StateConditionBase(Condition):
self._state = state
@override
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker."""
def check_any_match_state(states: list[str]) -> bool:
@@ -78,12 +78,11 @@ class StateConditionBase(Condition):
elif self._behavior == BEHAVIOR_ALL:
matcher = check_all_match_state
@trace_condition_function
def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test state condition."""
target_selection = target.TargetSelection(self._target)
targeted_entities = target.async_extract_referenced_entity_ids(
hass, target_selection, expand_group=False
self._hass, target_selection, expand_group=False
)
referenced_entity_ids = targeted_entities.referenced.union(
targeted_entities.indirectly_referenced
@@ -96,7 +95,7 @@ class StateConditionBase(Condition):
light_entity_states = [
state.state
for entity_id in light_entity_ids
if (state := hass.states.get(entity_id))
if (state := self._hass.states.get(entity_id))
and state.state in STATE_CONDITION_VALID_STATES
]
return matcher(light_entity_states)

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any, cast
from typing import Any, Unpack, cast
import voluptuous as vol
@@ -13,14 +13,14 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
condition_trace_set_result,
condition_trace_update_result,
trace_condition_function,
)
from homeassistant.helpers.sun import get_astral_event_date
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util
_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
@@ -154,17 +154,16 @@ class SunCondition(Condition):
assert config.options is not None
self._options = config.options
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with sun based condition."""
before = self._options.get("before")
after = self._options.get("after")
before_offset = self._options.get("before_offset")
after_offset = self._options.get("after_offset")
@trace_condition_function
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset)
return sun(self._hass, before, after, before_offset, after_offset)
return sun_if

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, cast
from typing import Any, Unpack, cast
import voluptuous as vol
@@ -22,11 +22,11 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
trace_condition_function,
)
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType
from . import in_zone
@@ -118,13 +118,12 @@ class ZoneCondition(Condition):
assert config.options is not None
self._options = config.options
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with zone based condition."""
entity_ids = self._options.get(CONF_ENTITY_ID, [])
zone_entity_ids = self._options.get(CONF_ZONE, [])
@trace_condition_function
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test if condition."""
errors = []
@@ -133,7 +132,7 @@ class ZoneCondition(Condition):
entity_ok = False
for zone_entity_id in zone_entity_ids:
try:
if zone(hass, zone_entity_id, entity_id):
if zone(self._hass, zone_entity_id, entity_id):
entity_ok = True
except ConditionErrorMessage as ex:
errors.append(

View File

@@ -13,7 +13,7 @@ import inspect
import logging
import re
import sys
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload
import voluptuous as vol
@@ -298,7 +298,7 @@ class Condition(abc.ABC):
self._hass = hass
@abc.abstractmethod
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker."""
@@ -319,7 +319,23 @@ class ConditionConfig:
target: dict[str, Any] | None = None
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
class ConditionCheckParams(TypedDict, total=False):
"""Condition check params."""
variables: TemplateVarsType
class ConditionChecker(Protocol):
"""Protocol for condition checker callable with typed kwargs."""
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Check the condition."""
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
type ConditionCheckerTypeOptional = Callable[
[HomeAssistant, TemplateVarsType], bool | None
]
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
@@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
trace_stack_pop(trace_stack_cv)
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
@overload
def trace_condition_function(
condition: ConditionCheckerType,
) -> ConditionCheckerType: ...
@overload
def trace_condition_function(
condition: ConditionCheckerTypeOptional,
) -> ConditionCheckerTypeOptional: ...
def trace_condition_function(
condition: ConditionCheckerType | ConditionCheckerTypeOptional,
) -> ConditionCheckerType | ConditionCheckerTypeOptional:
"""Wrap a condition function to enable basic tracing."""
@ft.wraps(condition)
@@ -420,10 +450,20 @@ async def _async_get_condition_platform(
) from None
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
new_checker = await condition.async_get_checker()
@trace_condition_function
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
return new_checker(variables=variables)
return checker
async def async_from_config(
hass: HomeAssistant,
config: ConfigType,
) -> ConditionCheckerType:
) -> ConditionCheckerTypeOptional:
"""Turn a condition configuration into a method.
Should be run on the event loop.
@@ -466,7 +506,7 @@ async def async_from_config(
target=config.get(CONF_TARGET),
),
)
return await condition.async_get_checker()
return await _async_get_checker(condition)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
@@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
name: str,
) -> Callable[[TemplateVarsType], bool]:
"""AND all conditions."""
checks: list[ConditionCheckerType] = [
checks = [
await async_from_config(hass, condition_config)
for condition_config in condition_configs
]
@@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
continue
description = {"fields": yaml_description.get("fields", {})}
if (target := yaml_description.get("target")) is not None:
description["target"] = target

View File

@@ -86,7 +86,7 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
from . import condition, config_validation as cv, service, template
from .condition import ConditionCheckerType, trace_condition_function
from .condition import ConditionCheckerTypeOptional, trace_condition_function
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
from .event import async_call_later, async_track_template
from .script_variables import ScriptRunVariables, ScriptVariables
@@ -675,12 +675,14 @@ class _ScriptRun:
### Condition actions ###
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
async def _async_get_condition(
self, config: ConfigType
) -> ConditionCheckerTypeOptional:
return await self._script._async_get_condition(config) # noqa: SLF001
def _test_conditions(
self,
conditions: list[ConditionCheckerType],
conditions: list[ConditionCheckerTypeOptional],
name: str,
condition_path: str | None = None,
) -> bool | None:
@@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
class _ChooseData(TypedDict):
choices: list[tuple[list[ConditionCheckerType], Script]]
choices: list[tuple[list[ConditionCheckerTypeOptional], Script]]
default: Script | None
class _IfData(TypedDict):
if_conditions: list[ConditionCheckerType]
if_conditions: list[ConditionCheckerTypeOptional]
if_then: Script
if_else: Script | None
@@ -1486,7 +1488,9 @@ class Script:
self._max_exceeded = max_exceeded
if script_mode == SCRIPT_MODE_QUEUED:
self._queue_lck = asyncio.Lock()
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
self._config_cache: dict[
frozenset[tuple[str, str]], ConditionCheckerTypeOptional
] = {}
self._repeat_script: dict[int, Script] = {}
self._choose_data: dict[int, _ChooseData] = {}
self._if_data: dict[int, _IfData] = {}
@@ -1857,7 +1861,9 @@ class Script:
return
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
async def _async_get_condition(
self, config: ConfigType
) -> ConditionCheckerTypeOptional:
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
if not (cond := self._config_cache.get(config_cache_key)):
cond = await condition.async_from_config(self._hass, config)

View File

@@ -36,7 +36,7 @@ from homeassistant.helpers import (
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType,
ConditionChecker,
async_validate_condition_config,
)
from homeassistant.helpers.template import Template
@@ -2126,16 +2126,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
class MockCondition1(MockCondition):
"""Mock condition 1."""
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Evaluate state based on configuration."""
return lambda hass, vars: True
return lambda **kwargs: True
class MockCondition2(MockCondition):
"""Mock condition 2."""
async def async_get_checker(self) -> ConditionCheckerType:
async def async_get_checker(self) -> ConditionChecker:
"""Evaluate state based on configuration."""
return lambda hass, vars: False
return lambda **kwargs: False
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
return {