mirror of
https://github.com/home-assistant/core.git
synced 2025-12-23 20:39:01 +00:00
Modernise condition checker in helper (#159159)
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@@ -11,18 +11,15 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
Condition,
|
Condition,
|
||||||
|
ConditionChecker,
|
||||||
ConditionCheckerType,
|
ConditionCheckerType,
|
||||||
ConditionConfig,
|
ConditionConfig,
|
||||||
trace_condition_function,
|
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
||||||
|
|
||||||
from . import DeviceAutomationType, async_get_device_automation_platform
|
from . import DeviceAutomationType, async_get_device_automation_platform
|
||||||
from .helpers import async_validate_device_automation_config
|
from .helpers import async_validate_device_automation_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from homeassistant.helpers import condition
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceAutomationConditionProtocol(Protocol):
|
class DeviceAutomationConditionProtocol(Protocol):
|
||||||
"""Define the format of device_condition modules.
|
"""Define the format of device_condition modules.
|
||||||
@@ -90,15 +87,21 @@ class DeviceCondition(Condition):
|
|||||||
assert config.options is not None
|
assert config.options is not None
|
||||||
self._config = config.options
|
self._config = config.options
|
||||||
|
|
||||||
async def async_get_checker(self) -> condition.ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Test a device condition."""
|
"""Test a device condition."""
|
||||||
platform = await async_get_device_automation_platform(
|
platform = await async_get_device_automation_platform(
|
||||||
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||||
)
|
)
|
||||||
return trace_condition_function(
|
platform_checker = platform.async_condition_from_config(
|
||||||
platform.async_condition_from_config(self._hass, self._config)
|
self._hass, self._config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
|
||||||
|
result = platform_checker(self._hass, variables)
|
||||||
|
return result is not False
|
||||||
|
|
||||||
|
return checker
|
||||||
|
|
||||||
|
|
||||||
CONDITIONS: dict[str, type[Condition]] = {
|
CONDITIONS: dict[str, type[Condition]] = {
|
||||||
"_device": DeviceCondition,
|
"_device": DeviceCondition,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Provides conditions for lights."""
|
"""Provides conditions for lights."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, Any, Final, override
|
from typing import TYPE_CHECKING, Any, Final, Unpack, override
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@@ -10,11 +10,11 @@ from homeassistant.core import HomeAssistant, split_entity_id
|
|||||||
from homeassistant.helpers import config_validation as cv, target
|
from homeassistant.helpers import config_validation as cv, target
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
Condition,
|
Condition,
|
||||||
ConditionCheckerType,
|
ConditionChecker,
|
||||||
|
ConditionCheckParams,
|
||||||
ConditionConfig,
|
ConditionConfig,
|
||||||
trace_condition_function,
|
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class StateConditionBase(Condition):
|
|||||||
self._state = state
|
self._state = state
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def async_get_checker(self) -> ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Get the condition checker."""
|
"""Get the condition checker."""
|
||||||
|
|
||||||
def check_any_match_state(states: list[str]) -> bool:
|
def check_any_match_state(states: list[str]) -> bool:
|
||||||
@@ -78,12 +78,11 @@ class StateConditionBase(Condition):
|
|||||||
elif self._behavior == BEHAVIOR_ALL:
|
elif self._behavior == BEHAVIOR_ALL:
|
||||||
matcher = check_all_match_state
|
matcher = check_all_match_state
|
||||||
|
|
||||||
@trace_condition_function
|
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||||
def test_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
|
||||||
"""Test state condition."""
|
"""Test state condition."""
|
||||||
target_selection = target.TargetSelection(self._target)
|
target_selection = target.TargetSelection(self._target)
|
||||||
targeted_entities = target.async_extract_referenced_entity_ids(
|
targeted_entities = target.async_extract_referenced_entity_ids(
|
||||||
hass, target_selection, expand_group=False
|
self._hass, target_selection, expand_group=False
|
||||||
)
|
)
|
||||||
referenced_entity_ids = targeted_entities.referenced.union(
|
referenced_entity_ids = targeted_entities.referenced.union(
|
||||||
targeted_entities.indirectly_referenced
|
targeted_entities.indirectly_referenced
|
||||||
@@ -96,7 +95,7 @@ class StateConditionBase(Condition):
|
|||||||
light_entity_states = [
|
light_entity_states = [
|
||||||
state.state
|
state.state
|
||||||
for entity_id in light_entity_ids
|
for entity_id in light_entity_ids
|
||||||
if (state := hass.states.get(entity_id))
|
if (state := self._hass.states.get(entity_id))
|
||||||
and state.state in STATE_CONDITION_VALID_STATES
|
and state.state in STATE_CONDITION_VALID_STATES
|
||||||
]
|
]
|
||||||
return matcher(light_entity_states)
|
return matcher(light_entity_states)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, cast
|
from typing import Any, Unpack, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@@ -13,14 +13,14 @@ from homeassistant.helpers import config_validation as cv
|
|||||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
Condition,
|
Condition,
|
||||||
ConditionCheckerType,
|
ConditionChecker,
|
||||||
|
ConditionCheckParams,
|
||||||
ConditionConfig,
|
ConditionConfig,
|
||||||
condition_trace_set_result,
|
condition_trace_set_result,
|
||||||
condition_trace_update_result,
|
condition_trace_update_result,
|
||||||
trace_condition_function,
|
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.sun import get_astral_event_date
|
from homeassistant.helpers.sun import get_astral_event_date
|
||||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
|
_OPTIONS_SCHEMA_DICT: dict[vol.Marker, Any] = {
|
||||||
@@ -154,17 +154,16 @@ class SunCondition(Condition):
|
|||||||
assert config.options is not None
|
assert config.options is not None
|
||||||
self._options = config.options
|
self._options = config.options
|
||||||
|
|
||||||
async def async_get_checker(self) -> ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Wrap action method with sun based condition."""
|
"""Wrap action method with sun based condition."""
|
||||||
before = self._options.get("before")
|
before = self._options.get("before")
|
||||||
after = self._options.get("after")
|
after = self._options.get("after")
|
||||||
before_offset = self._options.get("before_offset")
|
before_offset = self._options.get("before_offset")
|
||||||
after_offset = self._options.get("after_offset")
|
after_offset = self._options.get("after_offset")
|
||||||
|
|
||||||
@trace_condition_function
|
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||||
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
|
||||||
"""Validate time based if-condition."""
|
"""Validate time based if-condition."""
|
||||||
return sun(hass, before, after, before_offset, after_offset)
|
return sun(self._hass, before, after, before_offset, after_offset)
|
||||||
|
|
||||||
return sun_if
|
return sun_if
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, cast
|
from typing import Any, Unpack, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@@ -22,11 +22,11 @@ from homeassistant.helpers import config_validation as cv
|
|||||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
Condition,
|
Condition,
|
||||||
ConditionCheckerType,
|
ConditionChecker,
|
||||||
|
ConditionCheckParams,
|
||||||
ConditionConfig,
|
ConditionConfig,
|
||||||
trace_condition_function,
|
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from . import in_zone
|
from . import in_zone
|
||||||
|
|
||||||
@@ -118,13 +118,12 @@ class ZoneCondition(Condition):
|
|||||||
assert config.options is not None
|
assert config.options is not None
|
||||||
self._options = config.options
|
self._options = config.options
|
||||||
|
|
||||||
async def async_get_checker(self) -> ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Wrap action method with zone based condition."""
|
"""Wrap action method with zone based condition."""
|
||||||
entity_ids = self._options.get(CONF_ENTITY_ID, [])
|
entity_ids = self._options.get(CONF_ENTITY_ID, [])
|
||||||
zone_entity_ids = self._options.get(CONF_ZONE, [])
|
zone_entity_ids = self._options.get(CONF_ZONE, [])
|
||||||
|
|
||||||
@trace_condition_function
|
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||||
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
|
||||||
"""Test if condition."""
|
"""Test if condition."""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
@@ -133,7 +132,7 @@ class ZoneCondition(Condition):
|
|||||||
entity_ok = False
|
entity_ok = False
|
||||||
for zone_entity_id in zone_entity_ids:
|
for zone_entity_id in zone_entity_ids:
|
||||||
try:
|
try:
|
||||||
if zone(hass, zone_entity_id, entity_id):
|
if zone(self._hass, zone_entity_id, entity_id):
|
||||||
entity_ok = True
|
entity_ok = True
|
||||||
except ConditionErrorMessage as ex:
|
except ConditionErrorMessage as ex:
|
||||||
errors.append(
|
errors.append(
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, Unpack, cast, overload
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@@ -298,7 +298,7 @@ class Condition(abc.ABC):
|
|||||||
self._hass = hass
|
self._hass = hass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def async_get_checker(self) -> ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Get the condition checker."""
|
"""Get the condition checker."""
|
||||||
|
|
||||||
|
|
||||||
@@ -319,7 +319,23 @@ class ConditionConfig:
|
|||||||
target: dict[str, Any] | None = None
|
target: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool | None]
|
class ConditionCheckParams(TypedDict, total=False):
|
||||||
|
"""Condition check params."""
|
||||||
|
|
||||||
|
variables: TemplateVarsType
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionChecker(Protocol):
|
||||||
|
"""Protocol for condition checker callable with typed kwargs."""
|
||||||
|
|
||||||
|
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||||
|
"""Check the condition."""
|
||||||
|
|
||||||
|
|
||||||
|
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
|
||||||
|
type ConditionCheckerTypeOptional = Callable[
|
||||||
|
[HomeAssistant, TemplateVarsType], bool | None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
|
def condition_trace_append(variables: TemplateVarsType, path: str) -> TraceElement:
|
||||||
@@ -374,7 +390,21 @@ def trace_condition(variables: TemplateVarsType) -> Generator[TraceElement]:
|
|||||||
trace_stack_pop(trace_stack_cv)
|
trace_stack_pop(trace_stack_cv)
|
||||||
|
|
||||||
|
|
||||||
def trace_condition_function(condition: ConditionCheckerType) -> ConditionCheckerType:
|
@overload
|
||||||
|
def trace_condition_function(
|
||||||
|
condition: ConditionCheckerType,
|
||||||
|
) -> ConditionCheckerType: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def trace_condition_function(
|
||||||
|
condition: ConditionCheckerTypeOptional,
|
||||||
|
) -> ConditionCheckerTypeOptional: ...
|
||||||
|
|
||||||
|
|
||||||
|
def trace_condition_function(
|
||||||
|
condition: ConditionCheckerType | ConditionCheckerTypeOptional,
|
||||||
|
) -> ConditionCheckerType | ConditionCheckerTypeOptional:
|
||||||
"""Wrap a condition function to enable basic tracing."""
|
"""Wrap a condition function to enable basic tracing."""
|
||||||
|
|
||||||
@ft.wraps(condition)
|
@ft.wraps(condition)
|
||||||
@@ -420,10 +450,20 @@ async def _async_get_condition_platform(
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
|
||||||
|
new_checker = await condition.async_get_checker()
|
||||||
|
|
||||||
|
@trace_condition_function
|
||||||
|
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||||
|
return new_checker(variables=variables)
|
||||||
|
|
||||||
|
return checker
|
||||||
|
|
||||||
|
|
||||||
async def async_from_config(
|
async def async_from_config(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
) -> ConditionCheckerType:
|
) -> ConditionCheckerTypeOptional:
|
||||||
"""Turn a condition configuration into a method.
|
"""Turn a condition configuration into a method.
|
||||||
|
|
||||||
Should be run on the event loop.
|
Should be run on the event loop.
|
||||||
@@ -466,7 +506,7 @@ async def async_from_config(
|
|||||||
target=config.get(CONF_TARGET),
|
target=config.get(CONF_TARGET),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await condition.async_get_checker()
|
return await _async_get_checker(condition)
|
||||||
|
|
||||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||||
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
|
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
|
||||||
@@ -1131,7 +1171,7 @@ async def async_conditions_from_config(
|
|||||||
name: str,
|
name: str,
|
||||||
) -> Callable[[TemplateVarsType], bool]:
|
) -> Callable[[TemplateVarsType], bool]:
|
||||||
"""AND all conditions."""
|
"""AND all conditions."""
|
||||||
checks: list[ConditionCheckerType] = [
|
checks = [
|
||||||
await async_from_config(hass, condition_config)
|
await async_from_config(hass, condition_config)
|
||||||
for condition_config in condition_configs
|
for condition_config in condition_configs
|
||||||
]
|
]
|
||||||
@@ -1330,7 +1370,6 @@ async def async_get_all_descriptions(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
description = {"fields": yaml_description.get("fields", {})}
|
description = {"fields": yaml_description.get("fields", {})}
|
||||||
|
|
||||||
if (target := yaml_description.get("target")) is not None:
|
if (target := yaml_description.get("target")) is not None:
|
||||||
description["target"] = target
|
description["target"] = target
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ from homeassistant.util.hass_dict import HassKey
|
|||||||
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
|
from homeassistant.util.signal_type import SignalType, SignalTypeFormat
|
||||||
|
|
||||||
from . import condition, config_validation as cv, service, template
|
from . import condition, config_validation as cv, service, template
|
||||||
from .condition import ConditionCheckerType, trace_condition_function
|
from .condition import ConditionCheckerTypeOptional, trace_condition_function
|
||||||
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
|
from .dispatcher import async_dispatcher_connect, async_dispatcher_send_internal
|
||||||
from .event import async_call_later, async_track_template
|
from .event import async_call_later, async_track_template
|
||||||
from .script_variables import ScriptRunVariables, ScriptVariables
|
from .script_variables import ScriptRunVariables, ScriptVariables
|
||||||
@@ -675,12 +675,14 @@ class _ScriptRun:
|
|||||||
|
|
||||||
### Condition actions ###
|
### Condition actions ###
|
||||||
|
|
||||||
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
|
async def _async_get_condition(
|
||||||
|
self, config: ConfigType
|
||||||
|
) -> ConditionCheckerTypeOptional:
|
||||||
return await self._script._async_get_condition(config) # noqa: SLF001
|
return await self._script._async_get_condition(config) # noqa: SLF001
|
||||||
|
|
||||||
def _test_conditions(
|
def _test_conditions(
|
||||||
self,
|
self,
|
||||||
conditions: list[ConditionCheckerType],
|
conditions: list[ConditionCheckerTypeOptional],
|
||||||
name: str,
|
name: str,
|
||||||
condition_path: str | None = None,
|
condition_path: str | None = None,
|
||||||
) -> bool | None:
|
) -> bool | None:
|
||||||
@@ -1404,12 +1406,12 @@ def _referenced_extract_ids(data: Any, key: str, found: set[str]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class _ChooseData(TypedDict):
|
class _ChooseData(TypedDict):
|
||||||
choices: list[tuple[list[ConditionCheckerType], Script]]
|
choices: list[tuple[list[ConditionCheckerTypeOptional], Script]]
|
||||||
default: Script | None
|
default: Script | None
|
||||||
|
|
||||||
|
|
||||||
class _IfData(TypedDict):
|
class _IfData(TypedDict):
|
||||||
if_conditions: list[ConditionCheckerType]
|
if_conditions: list[ConditionCheckerTypeOptional]
|
||||||
if_then: Script
|
if_then: Script
|
||||||
if_else: Script | None
|
if_else: Script | None
|
||||||
|
|
||||||
@@ -1486,7 +1488,9 @@ class Script:
|
|||||||
self._max_exceeded = max_exceeded
|
self._max_exceeded = max_exceeded
|
||||||
if script_mode == SCRIPT_MODE_QUEUED:
|
if script_mode == SCRIPT_MODE_QUEUED:
|
||||||
self._queue_lck = asyncio.Lock()
|
self._queue_lck = asyncio.Lock()
|
||||||
self._config_cache: dict[frozenset[tuple[str, str]], ConditionCheckerType] = {}
|
self._config_cache: dict[
|
||||||
|
frozenset[tuple[str, str]], ConditionCheckerTypeOptional
|
||||||
|
] = {}
|
||||||
self._repeat_script: dict[int, Script] = {}
|
self._repeat_script: dict[int, Script] = {}
|
||||||
self._choose_data: dict[int, _ChooseData] = {}
|
self._choose_data: dict[int, _ChooseData] = {}
|
||||||
self._if_data: dict[int, _IfData] = {}
|
self._if_data: dict[int, _IfData] = {}
|
||||||
@@ -1857,7 +1861,9 @@ class Script:
|
|||||||
return
|
return
|
||||||
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
|
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
|
||||||
|
|
||||||
async def _async_get_condition(self, config: ConfigType) -> ConditionCheckerType:
|
async def _async_get_condition(
|
||||||
|
self, config: ConfigType
|
||||||
|
) -> ConditionCheckerTypeOptional:
|
||||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||||
if not (cond := self._config_cache.get(config_cache_key)):
|
if not (cond := self._config_cache.get(config_cache_key)):
|
||||||
cond = await condition.async_from_config(self._hass, config)
|
cond = await condition.async_from_config(self._hass, config)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from homeassistant.helpers import (
|
|||||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||||
from homeassistant.helpers.condition import (
|
from homeassistant.helpers.condition import (
|
||||||
Condition,
|
Condition,
|
||||||
ConditionCheckerType,
|
ConditionChecker,
|
||||||
async_validate_condition_config,
|
async_validate_condition_config,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.template import Template
|
from homeassistant.helpers.template import Template
|
||||||
@@ -2126,16 +2126,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
|
|||||||
class MockCondition1(MockCondition):
|
class MockCondition1(MockCondition):
|
||||||
"""Mock condition 1."""
|
"""Mock condition 1."""
|
||||||
|
|
||||||
async def async_get_checker(self) -> ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Evaluate state based on configuration."""
|
"""Evaluate state based on configuration."""
|
||||||
return lambda hass, vars: True
|
return lambda **kwargs: True
|
||||||
|
|
||||||
class MockCondition2(MockCondition):
|
class MockCondition2(MockCondition):
|
||||||
"""Mock condition 2."""
|
"""Mock condition 2."""
|
||||||
|
|
||||||
async def async_get_checker(self) -> ConditionCheckerType:
|
async def async_get_checker(self) -> ConditionChecker:
|
||||||
"""Evaluate state based on configuration."""
|
"""Evaluate state based on configuration."""
|
||||||
return lambda hass, vars: False
|
return lambda **kwargs: False
|
||||||
|
|
||||||
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
|
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user