1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-20 02:48:57 +00:00

Add custom (external) wake words (#152919)

This commit is contained in:
Michael Hansen
2025-10-27 13:15:56 -05:00
committed by GitHub
parent c782489973
commit 87e7fe6e37
11 changed files with 238 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.issue_registry import async_delete_issue
from homeassistant.helpers.typing import ConfigType
from . import dashboard, ffmpeg_proxy
from . import assist_satellite, dashboard, ffmpeg_proxy
from .const import CONF_BLUETOOTH_MAC_ADDRESS, CONF_NOISE_PSK, DOMAIN
from .domain_data import DomainData
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
@@ -31,6 +31,7 @@ CLIENT_INFO = f"Home Assistant {ha_version}"
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the esphome component."""
ffmpeg_proxy.async_setup(hass)
await assist_satellite.async_setup(hass)
await dashboard.async_setup(hass)
return True

View File

@@ -5,9 +5,12 @@ from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable
from functools import partial
import hashlib
import io
from itertools import chain
import json
import logging
from pathlib import Path
import socket
from typing import Any, cast
import wave
@@ -19,9 +22,12 @@ from aioesphomeapi import (
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantExternalWakeWord,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import (
@@ -29,6 +35,7 @@ from homeassistant.components.assist_pipeline import (
PipelineEventType,
PipelineStage,
)
from homeassistant.components.http import StaticPathConfig
from homeassistant.components.intent import (
TimerEventType,
TimerInfo,
@@ -39,8 +46,11 @@ from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.helpers.network import get_url
from homeassistant.helpers.singleton import singleton
from homeassistant.util.hass_dict import HassKey
from .const import DOMAIN
from .const import DOMAIN, WAKE_WORDS_API_PATH, WAKE_WORDS_DIR_NAME
from .entity import EsphomeAssistEntity, convert_api_error_ha_error
from .entry_data import ESPHomeConfigEntry
from .enum_mapper import EsphomeEnumMapper
@@ -84,6 +94,16 @@ _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventTy
_ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60 # 5 minutes
_CONFIG_TIMEOUT_SEC = 5
_WAKE_WORD_CONFIG_SCHEMA = vol.Schema(
{
vol.Required("type"): str,
vol.Required("wake_word"): str,
},
extra=vol.ALLOW_EXTRA,
)
_DATA_WAKE_WORDS: HassKey[dict[str, VoiceAssistantExternalWakeWord]] = HassKey(
"wake_word_cache"
)
async def async_setup_entry(
@@ -182,9 +202,14 @@ class EsphomeAssistSatellite(
async def _update_satellite_config(self) -> None:
"""Get the latest satellite configuration from the device."""
wake_words = await async_get_custom_wake_words(self.hass)
if wake_words:
_LOGGER.debug("Found custom wake words: %s", sorted(wake_words.keys()))
try:
config = await self.cli.get_voice_assistant_configuration(
_CONFIG_TIMEOUT_SEC
_CONFIG_TIMEOUT_SEC,
external_wake_words=list(wake_words.values()),
)
except TimeoutError:
# Placeholder config will be used
@@ -784,3 +809,78 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
return
self.transport.sendto(data, self.remote_addr)
async def async_get_custom_wake_words(
hass: HomeAssistant,
) -> dict[str, VoiceAssistantExternalWakeWord]:
"""Get available custom wake words."""
return await hass.async_add_executor_job(_get_custom_wake_words, hass)
@singleton(_DATA_WAKE_WORDS)
def _get_custom_wake_words(
hass: HomeAssistant,
) -> dict[str, VoiceAssistantExternalWakeWord]:
"""Get available custom wake words (singleton)."""
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
wake_words: dict[str, VoiceAssistantExternalWakeWord] = {}
# Look for config/model files
for config_path in wake_words_dir.glob("*.json"):
wake_word_id = config_path.stem
model_path = config_path.with_suffix(".tflite")
if not model_path.exists():
# Missing model file
continue
with open(config_path, encoding="utf-8") as config_file:
config_dict = json.load(config_file)
try:
config = _WAKE_WORD_CONFIG_SCHEMA(config_dict)
except vol.Invalid as err:
# Invalid config
_LOGGER.debug(
"Invalid wake word config: path=%s, error=%s",
config_path,
humanize_error(config_dict, err),
)
continue
with open(model_path, "rb") as model_file:
model_hash = hashlib.sha256(model_file.read()).hexdigest()
model_size = model_path.stat().st_size
config_rel_path = config_path.relative_to(wake_words_dir)
# Only intended for the internal network
base_url = get_url(hass, prefer_external=False, allow_cloud=False)
wake_words[wake_word_id] = VoiceAssistantExternalWakeWord.from_dict(
{
"id": wake_word_id,
"wake_word": config["wake_word"],
"trained_languages": config_dict.get("trained_languages", []),
"model_type": config["type"],
"model_size": model_size,
"model_hash": model_hash,
"url": f"{base_url}{WAKE_WORDS_API_PATH}/{config_rel_path}",
}
)
return wake_words
async def async_setup(hass: HomeAssistant) -> None:
"""Set up the satellite."""
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
# Satellites will pull model files over HTTP
await hass.http.async_register_static_paths(
[
StaticPathConfig(
url_path=WAKE_WORDS_API_PATH,
path=str(wake_words_dir),
)
]
)

View File

@@ -27,3 +27,6 @@ STABLE_BLE_URL_VERSION = f"{STABLE_BLE_VERSION.major}.{STABLE_BLE_VERSION.minor}
DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html"
NO_WAKE_WORD: Final[str] = "no_wake_word"
WAKE_WORDS_DIR_NAME = "custom_wake_words"
WAKE_WORDS_API_PATH = "/api/esphome/wake_words"

View File

@@ -2,6 +2,7 @@
import asyncio
from dataclasses import replace
from http import HTTPStatus
import io
import socket
from unittest.mock import ANY, AsyncMock, Mock, patch
@@ -55,6 +56,7 @@ from .common import get_satellite_entity
from .conftest import MockESPHomeDeviceType
from tests.components.tts.common import MockResultStream
from tests.typing import ClientSessionGenerator
@pytest.fixture
@@ -2087,3 +2089,80 @@ async def test_secondary_pipeline(
# Primary pipeline should be restored after
assert (await get_pipeline(None)) == "Primary Pipeline"
async def test_custom_wake_words(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: MockESPHomeDeviceType,
hass_client: ClientSessionGenerator,
) -> None:
"""Test exposing custom wake word models.
Expects 2 models in testing_config/custom_wake_words:
- hey_home_assistant
- choo_choo_homie
"""
http_client = await hass_client()
expected_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
],
active_wake_words=["1234"],
max_active_wake_words=1,
)
gvac = mock_client.get_voice_assistant_configuration
gvac.return_value = expected_config
mock_device = await mock_esphome_device(
mock_client=mock_client,
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
# Models should be present in testing_config/custom_wake_words
gvac.assert_called_once()
external_wake_words = gvac.call_args_list[0].kwargs["external_wake_words"]
assert len(external_wake_words) == 2
assert {external_wake_words[0].id, external_wake_words[1].id} == {
"hey_home_assistant",
"choo_choo_homie",
}
# Verify details
for eww in external_wake_words:
if eww.id == "hey_home_assistant":
assert eww.wake_word == "Hey Home Assistant"
else:
assert eww.wake_word == "Choo Choo Homie"
assert eww.model_type == "micro"
assert eww.model_size == 4 # tflite files contain "test"
assert (
eww.model_hash
== "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
)
assert eww.trained_languages == ["en"]
# GET config
config_url = eww.url[eww.url.find("/api") :]
req = await http_client.get(config_url)
assert req.status == HTTPStatus.OK
config_dict = await req.json()
# GET model
model = config_dict["model"]
model_url = config_url[: config_url.rfind("/")] + f"/{model}"
req = await http_client.get(model_url)
assert req.status == HTTPStatus.OK
# Check non-existent wake word
req = await http_client.get("/api/esphome/wake_words/wrong_wake_word.json")
assert req.status == HTTPStatus.NOT_FOUND

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,16 @@
{
"type": "micro",
"wake_word": "Choo Choo Homie",
"author": "Michael Hansen",
"website": "https://www.home-assistant.io",
"model": "choo_choo_homie.tflite",
"trained_languages": ["en"],
"version": 2,
"micro": {
"probability_cutoff": 0.97,
"feature_step_size": 10,
"sliding_window_size": 5,
"tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0"
}
}

View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,16 @@
{
"type": "micro",
"wake_word": "Hey Home Assistant",
"author": "Michael Hansen",
"website": "https://www.home-assistant.io",
"model": "hey_home_assistant.tflite",
"trained_languages": ["en"],
"version": 2,
"micro": {
"probability_cutoff": 0.97,
"feature_step_size": 10,
"sliding_window_size": 5,
"tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0"
}
}

View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,16 @@
{
"type": "micro",
"wake_word": "Missing Model",
"author": "Michael Hansen",
"website": "https://www.home-assistant.io",
"model": "missing_model.tflite",
"trained_languages": ["en"],
"version": 2,
"micro": {
"probability_cutoff": 0.97,
"feature_step_size": 10,
"sliding_window_size": 5,
"tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0"
}
}