1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 12:59:34 +00:00

Add support for multiple entity_ids in conditions (#36817)

This commit is contained in:
Franck Nijhof
2020-06-15 22:54:19 +02:00
committed by GitHub
parent 0a219081ea
commit ba73ac12ba
3 changed files with 153 additions and 14 deletions

View File

@@ -238,7 +238,7 @@ def async_numeric_state_from_config(
"""Wrap action method with state based condition."""
if config_validation:
config = cv.NUMERIC_STATE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
entity_ids = config.get(CONF_ENTITY_ID, [])
below = config.get(CONF_BELOW)
above = config.get(CONF_ABOVE)
value_template = config.get(CONF_VALUE_TEMPLATE)
@@ -250,8 +250,11 @@ def async_numeric_state_from_config(
if value_template is not None:
value_template.hass = hass
return async_numeric_state(
hass, entity_id, below, above, value_template, variables
return all(
async_numeric_state(
hass, entity_id, below, above, value_template, variables
)
for entity_id in entity_ids
)
return if_numeric_state
@@ -288,13 +291,15 @@ def state_from_config(
"""Wrap action method with state based condition."""
if config_validation:
config = cv.STATE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
entity_ids = config.get(CONF_ENTITY_ID, [])
req_state = cast(str, config.get(CONF_STATE))
for_period = config.get("for")
def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return state(hass, entity_id, req_state, for_period)
return all(
state(hass, entity_id, req_state, for_period) for entity_id in entity_ids
)
return if_state
@@ -506,12 +511,12 @@ def zone_from_config(
"""Wrap action method with zone based condition."""
if config_validation:
config = cv.ZONE_CONDITION_SCHEMA(config)
entity_id = config.get(CONF_ENTITY_ID)
entity_ids = config.get(CONF_ENTITY_ID, [])
zone_entity_id = config.get(CONF_ZONE)
def if_in_zone(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
return zone(hass, zone_entity_id, entity_id)
return all(zone(hass, zone_entity_id, entity_id) for entity_id in entity_ids)
return if_in_zone
@@ -556,7 +561,7 @@ async def async_validate_condition_config(
@callback
def async_extract_entities(config: ConfigType) -> Set[str]:
"""Extract entities from a condition."""
referenced = set()
referenced: Set[str] = set()
to_process = deque([config])
while to_process:
@@ -567,10 +572,13 @@ def async_extract_entities(config: ConfigType) -> Set[str]:
to_process.extend(config["conditions"])
continue
entity_id = config.get(CONF_ENTITY_ID)
entity_ids = config.get(CONF_ENTITY_ID)
if entity_id is not None:
referenced.add(entity_id)
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
if entity_ids is not None:
referenced.update(entity_ids)
return referenced