mirror of
https://github.com/home-assistant/core.git
synced 2025-12-20 02:48:57 +00:00
Add custom (external) wake words (#152919)
This commit is contained in:
@@ -17,7 +17,7 @@ from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.issue_registry import async_delete_issue
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import dashboard, ffmpeg_proxy
|
||||
from . import assist_satellite, dashboard, ffmpeg_proxy
|
||||
from .const import CONF_BLUETOOTH_MAC_ADDRESS, CONF_NOISE_PSK, DOMAIN
|
||||
from .domain_data import DomainData
|
||||
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||
@@ -31,6 +31,7 @@ CLIENT_INFO = f"Home Assistant {ha_version}"
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the esphome component."""
|
||||
ffmpeg_proxy.async_setup(hass)
|
||||
await assist_satellite.async_setup(hass)
|
||||
await dashboard.async_setup(hass)
|
||||
return True
|
||||
|
||||
|
||||
@@ -5,9 +5,12 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from functools import partial
|
||||
import hashlib
|
||||
import io
|
||||
from itertools import chain
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import socket
|
||||
from typing import Any, cast
|
||||
import wave
|
||||
@@ -19,9 +22,12 @@ from aioesphomeapi import (
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
VoiceAssistantExternalWakeWord,
|
||||
VoiceAssistantFeature,
|
||||
VoiceAssistantTimerEventType,
|
||||
)
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.components import assist_satellite, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
@@ -29,6 +35,7 @@ from homeassistant.components.assist_pipeline import (
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
)
|
||||
from homeassistant.components.http import StaticPathConfig
|
||||
from homeassistant.components.intent import (
|
||||
TimerEventType,
|
||||
TimerInfo,
|
||||
@@ -39,8 +46,11 @@ from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.helpers.singleton import singleton
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, WAKE_WORDS_API_PATH, WAKE_WORDS_DIR_NAME
|
||||
from .entity import EsphomeAssistEntity, convert_api_error_ha_error
|
||||
from .entry_data import ESPHomeConfigEntry
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
@@ -84,6 +94,16 @@ _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventTy
|
||||
|
||||
_ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60 # 5 minutes
|
||||
_CONFIG_TIMEOUT_SEC = 5
|
||||
_WAKE_WORD_CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required("type"): str,
|
||||
vol.Required("wake_word"): str,
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
_DATA_WAKE_WORDS: HassKey[dict[str, VoiceAssistantExternalWakeWord]] = HassKey(
|
||||
"wake_word_cache"
|
||||
)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
@@ -182,9 +202,14 @@ class EsphomeAssistSatellite(
|
||||
|
||||
async def _update_satellite_config(self) -> None:
|
||||
"""Get the latest satellite configuration from the device."""
|
||||
wake_words = await async_get_custom_wake_words(self.hass)
|
||||
if wake_words:
|
||||
_LOGGER.debug("Found custom wake words: %s", sorted(wake_words.keys()))
|
||||
|
||||
try:
|
||||
config = await self.cli.get_voice_assistant_configuration(
|
||||
_CONFIG_TIMEOUT_SEC
|
||||
_CONFIG_TIMEOUT_SEC,
|
||||
external_wake_words=list(wake_words.values()),
|
||||
)
|
||||
except TimeoutError:
|
||||
# Placeholder config will be used
|
||||
@@ -784,3 +809,78 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
return
|
||||
|
||||
self.transport.sendto(data, self.remote_addr)
|
||||
|
||||
|
||||
async def async_get_custom_wake_words(
|
||||
hass: HomeAssistant,
|
||||
) -> dict[str, VoiceAssistantExternalWakeWord]:
|
||||
"""Get available custom wake words."""
|
||||
return await hass.async_add_executor_job(_get_custom_wake_words, hass)
|
||||
|
||||
|
||||
@singleton(_DATA_WAKE_WORDS)
|
||||
def _get_custom_wake_words(
|
||||
hass: HomeAssistant,
|
||||
) -> dict[str, VoiceAssistantExternalWakeWord]:
|
||||
"""Get available custom wake words (singleton)."""
|
||||
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
|
||||
wake_words: dict[str, VoiceAssistantExternalWakeWord] = {}
|
||||
|
||||
# Look for config/model files
|
||||
for config_path in wake_words_dir.glob("*.json"):
|
||||
wake_word_id = config_path.stem
|
||||
model_path = config_path.with_suffix(".tflite")
|
||||
if not model_path.exists():
|
||||
# Missing model file
|
||||
continue
|
||||
|
||||
with open(config_path, encoding="utf-8") as config_file:
|
||||
config_dict = json.load(config_file)
|
||||
try:
|
||||
config = _WAKE_WORD_CONFIG_SCHEMA(config_dict)
|
||||
except vol.Invalid as err:
|
||||
# Invalid config
|
||||
_LOGGER.debug(
|
||||
"Invalid wake word config: path=%s, error=%s",
|
||||
config_path,
|
||||
humanize_error(config_dict, err),
|
||||
)
|
||||
continue
|
||||
|
||||
with open(model_path, "rb") as model_file:
|
||||
model_hash = hashlib.sha256(model_file.read()).hexdigest()
|
||||
|
||||
model_size = model_path.stat().st_size
|
||||
config_rel_path = config_path.relative_to(wake_words_dir)
|
||||
|
||||
# Only intended for the internal network
|
||||
base_url = get_url(hass, prefer_external=False, allow_cloud=False)
|
||||
|
||||
wake_words[wake_word_id] = VoiceAssistantExternalWakeWord.from_dict(
|
||||
{
|
||||
"id": wake_word_id,
|
||||
"wake_word": config["wake_word"],
|
||||
"trained_languages": config_dict.get("trained_languages", []),
|
||||
"model_type": config["type"],
|
||||
"model_size": model_size,
|
||||
"model_hash": model_hash,
|
||||
"url": f"{base_url}{WAKE_WORDS_API_PATH}/{config_rel_path}",
|
||||
}
|
||||
)
|
||||
|
||||
return wake_words
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the satellite."""
|
||||
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
|
||||
|
||||
# Satellites will pull model files over HTTP
|
||||
await hass.http.async_register_static_paths(
|
||||
[
|
||||
StaticPathConfig(
|
||||
url_path=WAKE_WORDS_API_PATH,
|
||||
path=str(wake_words_dir),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -27,3 +27,6 @@ STABLE_BLE_URL_VERSION = f"{STABLE_BLE_VERSION.major}.{STABLE_BLE_VERSION.minor}
|
||||
DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html"
|
||||
|
||||
NO_WAKE_WORD: Final[str] = "no_wake_word"
|
||||
|
||||
WAKE_WORDS_DIR_NAME = "custom_wake_words"
|
||||
WAKE_WORDS_API_PATH = "/api/esphome/wake_words"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from dataclasses import replace
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
import socket
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
@@ -55,6 +56,7 @@ from .common import get_satellite_entity
|
||||
from .conftest import MockESPHomeDeviceType
|
||||
|
||||
from tests.components.tts.common import MockResultStream
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -2087,3 +2089,80 @@ async def test_secondary_pipeline(
|
||||
|
||||
# Primary pipeline should be restored after
|
||||
assert (await get_pipeline(None)) == "Primary Pipeline"
|
||||
|
||||
|
||||
async def test_custom_wake_words(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: MockESPHomeDeviceType,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test exposing custom wake word models.
|
||||
|
||||
Expects 2 models in testing_config/custom_wake_words:
|
||||
- hey_home_assistant
|
||||
- choo_choo_homie
|
||||
"""
|
||||
http_client = await hass_client()
|
||||
expected_config = AssistSatelliteConfiguration(
|
||||
available_wake_words=[
|
||||
AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
|
||||
],
|
||||
active_wake_words=["1234"],
|
||||
max_active_wake_words=1,
|
||||
)
|
||||
gvac = mock_client.get_voice_assistant_configuration
|
||||
gvac.return_value = expected_config
|
||||
|
||||
mock_device = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
# Models should be present in testing_config/custom_wake_words
|
||||
gvac.assert_called_once()
|
||||
external_wake_words = gvac.call_args_list[0].kwargs["external_wake_words"]
|
||||
assert len(external_wake_words) == 2
|
||||
|
||||
assert {external_wake_words[0].id, external_wake_words[1].id} == {
|
||||
"hey_home_assistant",
|
||||
"choo_choo_homie",
|
||||
}
|
||||
|
||||
# Verify details
|
||||
for eww in external_wake_words:
|
||||
if eww.id == "hey_home_assistant":
|
||||
assert eww.wake_word == "Hey Home Assistant"
|
||||
else:
|
||||
assert eww.wake_word == "Choo Choo Homie"
|
||||
|
||||
assert eww.model_type == "micro"
|
||||
assert eww.model_size == 4 # tflite files contain "test"
|
||||
assert (
|
||||
eww.model_hash
|
||||
== "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
|
||||
)
|
||||
assert eww.trained_languages == ["en"]
|
||||
|
||||
# GET config
|
||||
config_url = eww.url[eww.url.find("/api") :]
|
||||
req = await http_client.get(config_url)
|
||||
assert req.status == HTTPStatus.OK
|
||||
config_dict = await req.json()
|
||||
|
||||
# GET model
|
||||
model = config_dict["model"]
|
||||
model_url = config_url[: config_url.rfind("/")] + f"/{model}"
|
||||
req = await http_client.get(model_url)
|
||||
assert req.status == HTTPStatus.OK
|
||||
|
||||
# Check non-existent wake word
|
||||
req = await http_client.get("/api/esphome/wake_words/wrong_wake_word.json")
|
||||
assert req.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
1
tests/testing_config/custom_wake_words/bad_config.json
Normal file
1
tests/testing_config/custom_wake_words/bad_config.json
Normal file
@@ -0,0 +1 @@
|
||||
{}
|
||||
1
tests/testing_config/custom_wake_words/bad_config.tflite
Normal file
1
tests/testing_config/custom_wake_words/bad_config.tflite
Normal file
@@ -0,0 +1 @@
|
||||
test
|
||||
16
tests/testing_config/custom_wake_words/choo_choo_homie.json
Normal file
16
tests/testing_config/custom_wake_words/choo_choo_homie.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"type": "micro",
|
||||
"wake_word": "Choo Choo Homie",
|
||||
"author": "Michael Hansen",
|
||||
"website": "https://www.home-assistant.io",
|
||||
"model": "choo_choo_homie.tflite",
|
||||
"trained_languages": ["en"],
|
||||
"version": 2,
|
||||
"micro": {
|
||||
"probability_cutoff": 0.97,
|
||||
"feature_step_size": 10,
|
||||
"sliding_window_size": 5,
|
||||
"tensor_arena_size": 30000,
|
||||
"minimum_esphome_version": "2024.7.0"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
test
|
||||
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"type": "micro",
|
||||
"wake_word": "Hey Home Assistant",
|
||||
"author": "Michael Hansen",
|
||||
"website": "https://www.home-assistant.io",
|
||||
"model": "hey_home_assistant.tflite",
|
||||
"trained_languages": ["en"],
|
||||
"version": 2,
|
||||
"micro": {
|
||||
"probability_cutoff": 0.97,
|
||||
"feature_step_size": 10,
|
||||
"sliding_window_size": 5,
|
||||
"tensor_arena_size": 30000,
|
||||
"minimum_esphome_version": "2024.7.0"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
test
|
||||
16
tests/testing_config/custom_wake_words/missing_model.json
Normal file
16
tests/testing_config/custom_wake_words/missing_model.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"type": "micro",
|
||||
"wake_word": "Missing Model",
|
||||
"author": "Michael Hansen",
|
||||
"website": "https://www.home-assistant.io",
|
||||
"model": "missing_model.tflite",
|
||||
"trained_languages": ["en"],
|
||||
"version": 2,
|
||||
"micro": {
|
||||
"probability_cutoff": 0.97,
|
||||
"feature_step_size": 10,
|
||||
"sliding_window_size": 5,
|
||||
"tensor_arena_size": 30000,
|
||||
"minimum_esphome_version": "2024.7.0"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user