mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 12:59:34 +00:00
Hassil intents (#85156)
* Add hassil to requirements * Add intent sentences * Update sentences * Use hassil to recognize intents in conversation * Fix tests * Bump hassil due to dependency conflict * Add dataclasses-json package contraints * Bump hassil (removes dataclasses-json dependency) * Remove climate sentences until intents are supported * Move I/O outside event loop * Bump hassil to 0.2.3 * Fix light tests * Handle areas in intents * Clean up code according to suggestions * Remove sentences from repo * Use home-assistant-intents package * Apply suggestions from code review * Flake8 Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
"""Module to coordinate user intentions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import voluptuous as vol
|
||||
@@ -16,7 +16,7 @@ from homeassistant.core import Context, HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from . import config_validation as cv
|
||||
from . import area_registry, config_validation as cv, entity_registry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_SlotsType = dict[str, Any]
|
||||
@@ -119,7 +119,25 @@ def async_match_state(
|
||||
if states is None:
|
||||
states = hass.states.async_all()
|
||||
|
||||
state = _fuzzymatch(name, states, lambda state: state.name)
|
||||
name = name.casefold()
|
||||
state: State | None = None
|
||||
registry = entity_registry.async_get(hass)
|
||||
|
||||
for maybe_state in states:
|
||||
# Check entity id and name
|
||||
if name in (maybe_state.entity_id, maybe_state.name.casefold()):
|
||||
state = maybe_state
|
||||
else:
|
||||
# Check aliases
|
||||
entry = registry.async_get(maybe_state.entity_id)
|
||||
if (entry is not None) and entry.aliases:
|
||||
for alias in entry.aliases:
|
||||
if name == alias.casefold():
|
||||
state = maybe_state
|
||||
break
|
||||
|
||||
if state is not None:
|
||||
break
|
||||
|
||||
if state is None:
|
||||
raise IntentHandleError(f"Unable to find an entity called {name}")
|
||||
@@ -127,6 +145,18 @@ def async_match_state(
|
||||
return state
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_area(
|
||||
hass: HomeAssistant, area_name: str
|
||||
) -> area_registry.AreaEntry | None:
|
||||
"""Find an area that matches the name."""
|
||||
registry = area_registry.async_get(hass)
|
||||
return registry.async_get_area(area_name) or registry.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_test_feature(state: State, feature: int, feature_name: str) -> None:
|
||||
"""Test if state supports a feature."""
|
||||
@@ -173,29 +203,17 @@ class IntentHandler:
|
||||
return f"<{self.__class__.__name__} - {self.intent_type}>"
|
||||
|
||||
|
||||
def _fuzzymatch(name: str, items: Iterable[_T], key: Callable[[_T], str]) -> _T | None:
|
||||
"""Fuzzy matching function."""
|
||||
matches = []
|
||||
pattern = ".*?".join(name)
|
||||
regex = re.compile(pattern, re.IGNORECASE)
|
||||
for idx, item in enumerate(items):
|
||||
if match := regex.search(key(item)):
|
||||
# Add key length so we prefer shorter keys with the same group and start.
|
||||
# Add index so we pick first match in case same group, start, and key length.
|
||||
matches.append(
|
||||
(len(match.group()), match.start(), len(key(item)), idx, item)
|
||||
)
|
||||
|
||||
return sorted(matches)[0][4] if matches else None
|
||||
|
||||
|
||||
class ServiceIntentHandler(IntentHandler):
|
||||
"""Service Intent handler registration.
|
||||
|
||||
Service specific intent handler that calls a service by name/entity_id.
|
||||
"""
|
||||
|
||||
slot_schema = {vol.Required("name"): cv.string}
|
||||
slot_schema = {
|
||||
vol.Any("name", "area"): cv.string,
|
||||
vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, intent_type: str, domain: str, service: str, speech: str
|
||||
@@ -210,26 +228,101 @@ class ServiceIntentHandler(IntentHandler):
|
||||
"""Handle the hass intent."""
|
||||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = async_match_state(hass, slots["name"]["value"])
|
||||
|
||||
await hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
if "area" in slots:
|
||||
# Entities in an area
|
||||
area_name = slots["area"]["value"]
|
||||
area = async_match_area(hass, area_name)
|
||||
assert area is not None
|
||||
assert area.id is not None
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(state.name))
|
||||
response.async_set_results(
|
||||
success_results=[
|
||||
# Optional domain filter
|
||||
domains: set[str] | None = None
|
||||
if "domain" in slots:
|
||||
domains = set(slots["domain"]["value"])
|
||||
|
||||
# Optional device class filter
|
||||
device_classes: set[str] | None = None
|
||||
if "device_class" in slots:
|
||||
device_classes = set(slots["device_class"]["value"])
|
||||
|
||||
success_results = [
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
),
|
||||
],
|
||||
)
|
||||
type=IntentResponseTargetType.AREA, name=area.name, id=area.id
|
||||
)
|
||||
]
|
||||
service_coros = []
|
||||
registry = entity_registry.async_get(hass)
|
||||
for entity_entry in entity_registry.async_entries_for_area(
|
||||
registry, area.id
|
||||
):
|
||||
if entity_entry.entity_category:
|
||||
# Skip diagnostic entities
|
||||
continue
|
||||
|
||||
if domains and (entity_entry.domain not in domains):
|
||||
# Skip entity not in the domain
|
||||
continue
|
||||
|
||||
if device_classes and (entity_entry.device_class not in device_classes):
|
||||
# Skip entity with wrong device class
|
||||
continue
|
||||
|
||||
service_coros.append(
|
||||
hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: entity_entry.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
)
|
||||
|
||||
state = hass.states.get(entity_entry.entity_id)
|
||||
assert state is not None
|
||||
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=entity_entry.entity_id,
|
||||
),
|
||||
)
|
||||
|
||||
if not service_coros:
|
||||
raise IntentHandleError("No entities matched")
|
||||
|
||||
# Handle service calls in parallel.
|
||||
# We will need to handle partial failures here.
|
||||
await asyncio.gather(*service_coros)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(area.name))
|
||||
response.async_set_results(
|
||||
success_results=success_results,
|
||||
)
|
||||
else:
|
||||
# Single entity
|
||||
state = async_match_state(hass, slots["name"]["value"])
|
||||
|
||||
await hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(state.name))
|
||||
response.async_set_results(
|
||||
success_results=[
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user