mirror of
https://github.com/home-assistant/core.git
synced 2025-12-20 02:48:57 +00:00
Modernise condition checker in helper (#159159)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -11,18 +11,15 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionChecker,
|
||||
ConditionCheckerType,
|
||||
ConditionConfig,
|
||||
trace_condition_function,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||
|
||||
from . import DeviceAutomationType, async_get_device_automation_platform
|
||||
from .helpers import async_validate_device_automation_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers import condition
|
||||
|
||||
|
||||
class DeviceAutomationConditionProtocol(Protocol):
|
||||
"""Define the format of device_condition modules.
|
||||
@@ -90,15 +87,21 @@ class DeviceCondition(Condition):
|
||||
assert config.options is not None
|
||||
self._config = config.options
|
||||
|
||||
async def async_get_checker(self) -> condition.ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Test a device condition."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
return trace_condition_function(
|
||||
platform.async_condition_from_config(self._hass, self._config)
|
||||
platform_checker = platform.async_condition_from_config(
|
||||
self._hass, self._config
|
||||
)
|
||||
|
||||
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
|
||||
result = platform_checker(self._hass, variables)
|
||||
return result is not False
|
||||
|
||||
return checker
|
||||
|
||||
|
||||
CONDITIONS: dict[str, type[Condition]] = {
|
||||
"_device": DeviceCondition,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Provides conditions for lights."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Final, override
|
||||
from typing import TYPE_CHECKING, Any, Final, Unpack, override
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -10,11 +10,11 @@ from homeassistant.core import HomeAssistant, split_entity_id
|
||||
from homeassistant.helpers import config_validation as cv, target
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionCheckerType,
|
||||
ConditionChecker,
|
||||
ConditionCheckParams,
|
||||
ConditionConfig,
|
||||
trace_condition_function,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
@@ -61,7 +61,7 @@ class StateConditionBase(Condition):
|
||||
self._state = state
|
||||
|
||||
@override
|
||||
async def async_get_checker(self) -> ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Get the condition checker."""
|
||||
|
||||
def check_any_match_state(states: list[str]) -> bool:
|
||||
@@ -78,12 +78,11 @@ class StateConditionBase(Condition):
|
||||
elif self._behavior == BEHAVIOR_ALL:
|
||||
matcher = check_all_match_state
|
||||
|
||||
@trace_condition_function
|
||||
def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test state condition."""
|
||||
target_selection = target.TargetSelection(self._target)
|
||||
targeted_entities = target.async_extract_referenced_entity_ids(
|
||||
hass, target_selection, expand_group=False
|
||||
self._hass, target_selection, expand_group=False
|
||||
)
|
||||
referenced_entity_ids = targeted_entities.referenced.union(
|
||||
targeted_entities.indirectly_referenced
|
||||
@@ -96,7 +95,7 @@ class StateConditionBase(Condition):
|
||||
light_entity_states = [
|
||||
state.state
|
||||
for entity_id in light_entity_ids
|
||||
if (state := hass.states.get(entity_id))
|
||||
if (state := self._hass.states.get(entity_id))
|
||||
and state.state in STATE_CONDITION_VALID_STATES
|
||||
]
|
||||
return matcher(light_entity_states)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast
|
||||
from typing import Any, Unpack, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -13,14 +13,14 @@ from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionCheckerType,
|
||||
ConditionChecker,
|
||||
ConditionCheckParams,
|
||||
ConditionConfig,
|
||||
condition_trace_set_result,
|
||||
condition_trace_update_result,
|
||||
trace_condition_function,
|
||||
)
|
||||
from homeassistant.helpers.sun import get_astral_event_date
|
||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
|
||||
@@ -154,17 +154,16 @@ class SunCondition(Condition):
|
||||
assert config.options is not None
|
||||
self._options = config.options
|
||||
|
||||
async def async_get_checker(self) -> ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Wrap action method with sun based condition."""
|
||||
before = self._options.get("before")
|
||||
after = self._options.get("after")
|
||||
before_offset = self._options.get("before_offset")
|
||||
after_offset = self._options.get("after_offset")
|
||||
|
||||
@trace_condition_function
|
||||
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Validate time based if-condition."""
|
||||
return sun(hass, before, after, before_offset, after_offset)
|
||||
return sun(self._hass, before, after, before_offset, after_offset)
|
||||
|
||||
return sun_if
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
from typing import Any, Unpack, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -22,11 +22,11 @@ from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionCheckerType,
|
||||
ConditionChecker,
|
||||
ConditionCheckParams,
|
||||
ConditionConfig,
|
||||
trace_condition_function,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import in_zone
|
||||
|
||||
@@ -118,13 +118,12 @@ class ZoneCondition(Condition):
|
||||
assert config.options is not None
|
||||
self._options = config.options
|
||||
|
||||
async def async_get_checker(self) -> ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Wrap action method with zone based condition."""
|
||||
entity_ids = self._options.get(CONF_ENTITY_ID, [])
|
||||
zone_entity_ids = self._options.get(CONF_ZONE, [])
|
||||
|
||||
@trace_condition_function
|
||||
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test if condition."""
|
||||
errors = []
|
||||
|
||||
@@ -133,7 +132,7 @@ class ZoneCondition(Condition):
|
||||
entity_ok = False
|
||||
for zone_entity_id in zone_entity_ids:
|
||||
try:
|
||||
if zone(hass, zone_entity_id, entity_id):
|
||||
if zone(self._hass, zone_entity_id, entity_id):
|
||||
entity_ok = True
|
||||
except ConditionErrorMessage as ex:
|
||||
errors.append(
|
||||
|
||||
@@ -13,7 +13,7 @@ import inspect
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -298,7 +298,7 @@ class Condition(abc.ABC):
|
||||
self._hass = hass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_get_checker(self) -> ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Get the condition checker."""
|
||||
|
||||
|
||||
@@ -319,7 +319,23 @@ class ConditionConfig:
|
||||
target: dict[str, Any] | None = None
|
||||
|
||||
|
||||
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
||||
class ConditionCheckParams(TypedDict, total=False):
|
||||
"""Condition check params."""
|
||||
|
||||
variables: TemplateVarsType
|
||||
|
||||
|
||||
class ConditionChecker(Protocol):
|
||||
"""Protocol for condition checker callable with typed kwargs."""
|
||||
|
||||
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Check the condition."""
|
||||
|
||||
|
||||
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
|
||||
type ConditionCheckerTypeOptional = Callable[
|
||||
[HomeAssistant, TemplateVarsType], bool | None
|
||||
]
|
||||
|
||||
|
||||
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
|
||||
@@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
|
||||
trace_stack_pop(trace_stack_cv)
|
||||
|
||||
|
||||
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
|
||||
@overload
|
||||
def trace_condition_function(
|
||||
condition: ConditionCheckerType,
|
||||
) -> ConditionCheckerType: ...
|
||||
|
||||
|
||||
@overload
|
||||
def trace_condition_function(
|
||||
condition: ConditionCheckerTypeOptional,
|
||||
) -> ConditionCheckerTypeOptional: ...
|
||||
|
||||
|
||||
def trace_condition_function(
|
||||
condition: ConditionCheckerType | ConditionCheckerTypeOptional,
|
||||
) -> ConditionCheckerType | ConditionCheckerTypeOptional:
|
||||
"""Wrap a condition function to enable basic tracing."""
|
||||
|
||||
@ft.wraps(condition)
|
||||
@@ -420,10 +450,20 @@ async def _async_get_condition_platform(
|
||||
) from None
|
||||
|
||||
|
||||
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
|
||||
new_checker = await condition.async_get_checker()
|
||||
|
||||
@trace_condition_function
|
||||
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
return new_checker(variables=variables)
|
||||
|
||||
return checker
|
||||
|
||||
|
||||
async def async_from_config(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
) -> ConditionCheckerType:
|
||||
) -> ConditionCheckerTypeOptional:
|
||||
"""Turn a condition configuration into a method.
|
||||
|
||||
Should be run on the event loop.
|
||||
@@ -466,7 +506,7 @@ async def async_from_config(
|
||||
target=config.get(CONF_TARGET),
|
||||
),
|
||||
)
|
||||
return await condition.async_get_checker()
|
||||
return await _async_get_checker(condition)
|
||||
|
||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
|
||||
@@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
|
||||
name: str,
|
||||
) -> Callable[[TemplateVarsType], bool]:
|
||||
"""AND all conditions."""
|
||||
checks: list[ConditionCheckerType] = [
|
||||
checks = [
|
||||
await async_from_config(hass, condition_config)
|
||||
for condition_config in condition_configs
|
||||
]
|
||||
@@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
|
||||
continue
|
||||
|
||||
description = {"fields": yaml_description.get("fields", {})}
|
||||
|
||||
if (target := yaml_description.get("target")) is not None:
|
||||
description["target"] = target
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
|
||||
|
||||
from . import condition, config_validation as cv, service, template
|
||||
from .condition import ConditionCheckerType, trace_condition_function
|
||||
from .condition import ConditionCheckerTypeOptional, trace_condition_function
|
||||
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
|
||||
from .event import async_call_later, async_track_template
|
||||
from .script_variables import ScriptRunVariables, ScriptVariables
|
||||
@@ -675,12 +675,14 @@ class _ScriptRun:
|
||||
|
||||
### Condition actions ###
|
||||
|
||||
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
|
||||
async def _async_get_condition(
|
||||
self, config: ConfigType
|
||||
) -> ConditionCheckerTypeOptional:
|
||||
return await self._script._async_get_condition(config) # noqa: SLF001
|
||||
|
||||
def _test_conditions(
|
||||
self,
|
||||
conditions: list[ConditionCheckerType],
|
||||
conditions: list[ConditionCheckerTypeOptional],
|
||||
name: str,
|
||||
condition_path: str | None = None,
|
||||
) -> bool | None:
|
||||
@@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
|
||||
|
||||
|
||||
class _ChooseData(TypedDict):
|
||||
choices: list[tuple[list[ConditionCheckerType], Script]]
|
||||
choices: list[tuple[list[ConditionCheckerTypeOptional], Script]]
|
||||
default: Script | None
|
||||
|
||||
|
||||
class _IfData(TypedDict):
|
||||
if_conditions: list[ConditionCheckerType]
|
||||
if_conditions: list[ConditionCheckerTypeOptional]
|
||||
if_then: Script
|
||||
if_else: Script | None
|
||||
|
||||
@@ -1486,7 +1488,9 @@ class Script:
|
||||
self._max_exceeded = max_exceeded
|
||||
if script_mode == SCRIPT_MODE_QUEUED:
|
||||
self._queue_lck = asyncio.Lock()
|
||||
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
|
||||
self._config_cache: dict[
|
||||
frozenset[tuple[str, str]], ConditionCheckerTypeOptional
|
||||
] = {}
|
||||
self._repeat_script: dict[int, Script] = {}
|
||||
self._choose_data: dict[int, _ChooseData] = {}
|
||||
self._if_data: dict[int, _IfData] = {}
|
||||
@@ -1857,7 +1861,9 @@ class Script:
|
||||
return
|
||||
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
|
||||
|
||||
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
|
||||
async def _async_get_condition(
|
||||
self, config: ConfigType
|
||||
) -> ConditionCheckerTypeOptional:
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
if not (cond := self._config_cache.get(config_cache_key)):
|
||||
cond = await condition.async_from_config(self._hass, config)
|
||||
|
||||
@@ -36,7 +36,7 @@ from homeassistant.helpers import (
|
||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionCheckerType,
|
||||
ConditionChecker,
|
||||
async_validate_condition_config,
|
||||
)
|
||||
from homeassistant.helpers.template import Template
|
||||
@@ -2126,16 +2126,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
|
||||
class MockCondition1(MockCondition):
|
||||
"""Mock condition 1."""
|
||||
|
||||
async def async_get_checker(self) -> ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Evaluate state based on configuration."""
|
||||
return lambda hass, vars: True
|
||||
return lambda **kwargs: True
|
||||
|
||||
class MockCondition2(MockCondition):
|
||||
"""Mock condition 2."""
|
||||
|
||||
async def async_get_checker(self) -> ConditionCheckerType:
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Evaluate state based on configuration."""
|
||||
return lambda hass, vars: False
|
||||
return lambda **kwargs: False
|
||||
|
||||
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user