mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 12:59:34 +00:00
HassTurnOn/Off intents to also handle cover entities (#86206)
* Move entity/area resolution to async_match_states * Special case for covers in HassTurnOn/Off * Enable light color/brightness on areas * Remove async_register from default agent * Remove CONFIG_SCHEMA from conversation component * Fix intent tests * Fix light test * Move entity/area resolution to async_match_states * Special case for covers in HassTurnOn/Off * Enable light color/brightness on areas * Remove async_register from default agent * Remove CONFIG_SCHEMA from conversation component * Fix intent tests * Fix light test * Fix humidifier intent handlers * Remove DATA_CONFIG for conversation * Copy ServiceIntentHandler code to light * Add proper errors to humidifier intent handlers
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Collection, Iterable
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -11,7 +11,11 @@ from typing import Any, TypeVar
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_SUPPORTED_FEATURES,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import bind_hass
|
||||
@@ -110,51 +114,117 @@ class IntentUnexpectedError(IntentError):
|
||||
"""Unexpected error while handling intent."""
|
||||
|
||||
|
||||
def _is_device_class(
|
||||
state: State,
|
||||
entity: entity_registry.RegistryEntry | None,
|
||||
device_classes: Collection[str],
|
||||
) -> bool:
|
||||
"""Return true if entity device class matches."""
|
||||
# Try entity first
|
||||
if (entity is not None) and (entity.device_class is not None):
|
||||
# Entity device class can be None or blank as "unset"
|
||||
if entity.device_class in device_classes:
|
||||
return True
|
||||
|
||||
# Fall back to state attribute
|
||||
device_class = state.attributes.get(ATTR_DEVICE_CLASS)
|
||||
return (device_class is not None) and (device_class in device_classes)
|
||||
|
||||
|
||||
def _has_name(
|
||||
state: State, entity: entity_registry.RegistryEntry | None, name: str
|
||||
) -> bool:
|
||||
"""Return true if entity name or alias matches."""
|
||||
if name in (state.entity_id, state.name.casefold()):
|
||||
return True
|
||||
|
||||
# Check aliases
|
||||
if (entity is not None) and entity.aliases:
|
||||
for alias in entity.aliases:
|
||||
if name == alias.casefold():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_state(
|
||||
hass: HomeAssistant, name: str, states: Iterable[State] | None = None
|
||||
) -> State:
|
||||
"""Find a state that matches the name."""
|
||||
def async_match_states(
|
||||
hass: HomeAssistant,
|
||||
name: str | None = None,
|
||||
area_name: str | None = None,
|
||||
area: area_registry.AreaEntry | None = None,
|
||||
domains: Collection[str] | None = None,
|
||||
device_classes: Collection[str] | None = None,
|
||||
states: Iterable[State] | None = None,
|
||||
entities: entity_registry.EntityRegistry | None = None,
|
||||
areas: area_registry.AreaRegistry | None = None,
|
||||
) -> Iterable[State]:
|
||||
"""Find states that match the constraints."""
|
||||
if states is None:
|
||||
# All states
|
||||
states = hass.states.async_all()
|
||||
|
||||
name = name.casefold()
|
||||
state: State | None = None
|
||||
registry = entity_registry.async_get(hass)
|
||||
if entities is None:
|
||||
entities = entity_registry.async_get(hass)
|
||||
|
||||
for maybe_state in states:
|
||||
# Check entity id and name
|
||||
if name in (maybe_state.entity_id, maybe_state.name.casefold()):
|
||||
state = maybe_state
|
||||
else:
|
||||
# Check aliases
|
||||
entry = registry.async_get(maybe_state.entity_id)
|
||||
if (entry is not None) and entry.aliases:
|
||||
for alias in entry.aliases:
|
||||
if name == alias.casefold():
|
||||
state = maybe_state
|
||||
break
|
||||
# Gather entities
|
||||
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]] = []
|
||||
for state in states:
|
||||
entity = entities.async_get(state.entity_id)
|
||||
if (entity is not None) and entity.entity_category:
|
||||
# Skip diagnostic entities
|
||||
continue
|
||||
|
||||
if state is not None:
|
||||
break
|
||||
states_and_entities.append((state, entity))
|
||||
|
||||
if state is None:
|
||||
raise IntentHandleError(f"Unable to find an entity called {name}")
|
||||
# Filter by domain and device class
|
||||
if domains:
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if state.domain in domains
|
||||
]
|
||||
|
||||
return state
|
||||
if device_classes:
|
||||
# Check device class in state attribute and in entity entry (if available)
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if _is_device_class(state, entity, device_classes)
|
||||
]
|
||||
|
||||
if (area is None) and (area_name is not None):
|
||||
# Look up area by name
|
||||
if areas is None:
|
||||
areas = area_registry.async_get(hass)
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_area(
|
||||
hass: HomeAssistant, area_name: str
|
||||
) -> area_registry.AreaEntry | None:
|
||||
"""Find an area that matches the name."""
|
||||
registry = area_registry.async_get(hass)
|
||||
return registry.async_get_area(area_name) or registry.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
# id or name
|
||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
assert area is not None, f"No area named {area_name}"
|
||||
|
||||
if area is not None:
|
||||
# Filter by area
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if (entity is not None) and (entity.area_id == area.id)
|
||||
]
|
||||
|
||||
if name is not None:
|
||||
# Filter by name
|
||||
name = name.casefold()
|
||||
|
||||
for state, entity in states_and_entities:
|
||||
if _has_name(state, entity, name):
|
||||
yield state
|
||||
break
|
||||
else:
|
||||
# Not filtered by name
|
||||
for state, _entity in states_and_entities:
|
||||
yield state
|
||||
|
||||
|
||||
@callback
|
||||
@@ -229,102 +299,103 @@ class ServiceIntentHandler(IntentHandler):
|
||||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
|
||||
if "area" in slots:
|
||||
# Entities in an area
|
||||
area_name = slots["area"]["value"]
|
||||
area = async_match_area(hass, area_name)
|
||||
assert area is not None
|
||||
assert area.id is not None
|
||||
name: str | None = slots.get("name", {}).get("value")
|
||||
if name == "all":
|
||||
# Don't match on name if targeting all entities
|
||||
name = None
|
||||
|
||||
# Optional domain filter
|
||||
domains: set[str] | None = None
|
||||
if "domain" in slots:
|
||||
domains = set(slots["domain"]["value"])
|
||||
# Look up area first to fail early
|
||||
area_name = slots.get("area", {}).get("value")
|
||||
area: area_registry.AreaEntry | None = None
|
||||
if area_name is not None:
|
||||
areas = area_registry.async_get(hass)
|
||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
if area is None:
|
||||
raise IntentHandleError(f"No area named {area_name}")
|
||||
|
||||
# Optional device class filter
|
||||
device_classes: set[str] | None = None
|
||||
if "device_class" in slots:
|
||||
device_classes = set(slots["device_class"]["value"])
|
||||
# Optional domain/device class filters.
|
||||
# Convert to sets for speed.
|
||||
domains: set[str] | None = None
|
||||
device_classes: set[str] | None = None
|
||||
|
||||
success_results = [
|
||||
if "domain" in slots:
|
||||
domains = set(slots["domain"]["value"])
|
||||
|
||||
if "device_class" in slots:
|
||||
device_classes = set(slots["device_class"]["value"])
|
||||
|
||||
states = list(
|
||||
async_match_states(
|
||||
hass,
|
||||
name=name,
|
||||
area=area,
|
||||
domains=domains,
|
||||
device_classes=device_classes,
|
||||
)
|
||||
)
|
||||
|
||||
if not states:
|
||||
raise IntentHandleError("No entities matched")
|
||||
|
||||
response = await self.async_handle_states(intent_obj, states, area)
|
||||
|
||||
return response
|
||||
|
||||
async def async_handle_states(
|
||||
self,
|
||||
intent_obj: Intent,
|
||||
states: list[State],
|
||||
area: area_registry.AreaEntry | None = None,
|
||||
) -> IntentResponse:
|
||||
"""Complete action on matched entity states."""
|
||||
assert states
|
||||
success_results: list[IntentResponseTarget] = []
|
||||
response = intent_obj.create_response()
|
||||
|
||||
if area is not None:
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.AREA, name=area.name, id=area.id
|
||||
)
|
||||
]
|
||||
service_coros = []
|
||||
registry = entity_registry.async_get(hass)
|
||||
for entity_entry in entity_registry.async_entries_for_area(
|
||||
registry, area.id
|
||||
):
|
||||
if entity_entry.entity_category:
|
||||
# Skip diagnostic entities
|
||||
continue
|
||||
|
||||
if domains and (entity_entry.domain not in domains):
|
||||
# Skip entity not in the domain
|
||||
continue
|
||||
|
||||
if device_classes and (entity_entry.device_class not in device_classes):
|
||||
# Skip entity with wrong device class
|
||||
continue
|
||||
|
||||
service_coros.append(
|
||||
hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: entity_entry.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
)
|
||||
|
||||
state = hass.states.get(entity_entry.entity_id)
|
||||
assert state is not None
|
||||
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=entity_entry.entity_id,
|
||||
),
|
||||
)
|
||||
|
||||
if not service_coros:
|
||||
raise IntentHandleError("No entities matched")
|
||||
|
||||
# Handle service calls in parallel.
|
||||
# We will need to handle partial failures here.
|
||||
await asyncio.gather(*service_coros)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(area.name))
|
||||
response.async_set_results(
|
||||
success_results=success_results,
|
||||
)
|
||||
speech_name = area.name
|
||||
else:
|
||||
# Single entity
|
||||
state = async_match_state(hass, slots["name"]["value"])
|
||||
speech_name = states[0].name
|
||||
|
||||
await hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
service_coros = []
|
||||
for state in states:
|
||||
service_coros.append(self.async_call_service(intent_obj, state))
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
),
|
||||
)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(state.name))
|
||||
response.async_set_results(
|
||||
success_results=[
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
),
|
||||
],
|
||||
)
|
||||
# Handle service calls in parallel.
|
||||
# We will need to handle partial failures here.
|
||||
await asyncio.gather(*service_coros)
|
||||
|
||||
response.async_set_results(
|
||||
success_results=success_results,
|
||||
)
|
||||
response.async_set_speech(self.speech.format(speech_name))
|
||||
|
||||
return response
|
||||
|
||||
async def async_call_service(self, intent_obj: Intent, state: State) -> None:
|
||||
"""Call service on entity."""
|
||||
hass = intent_obj.hass
|
||||
await hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
|
||||
|
||||
class IntentCategory(Enum):
|
||||
"""Category of an intent."""
|
||||
|
||||
Reference in New Issue
Block a user