1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 12:59:34 +00:00

Add custom (external) wake words (#152919)

This commit is contained in:
Michael Hansen
2025-10-27 13:15:56 -05:00
committed by GitHub
parent c782489973
commit 87e7fe6e37
11 changed files with 238 additions and 3 deletions

View File

@@ -17,7 +17,7 @@ from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.issue_registry import async_delete_issue from homeassistant.helpers.issue_registry import async_delete_issue
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import dashboard, ffmpeg_proxy from . import assist_satellite, dashboard, ffmpeg_proxy
from .const import CONF_BLUETOOTH_MAC_ADDRESS, CONF_NOISE_PSK, DOMAIN from .const import CONF_BLUETOOTH_MAC_ADDRESS, CONF_NOISE_PSK, DOMAIN
from .domain_data import DomainData from .domain_data import DomainData
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
@@ -31,6 +31,7 @@ CLIENT_INFO = f"Home Assistant {ha_version}"
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the esphome component.""" """Set up the esphome component."""
ffmpeg_proxy.async_setup(hass) ffmpeg_proxy.async_setup(hass)
await assist_satellite.async_setup(hass)
await dashboard.async_setup(hass) await dashboard.async_setup(hass)
return True return True

View File

@@ -5,9 +5,12 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from functools import partial from functools import partial
import hashlib
import io import io
from itertools import chain from itertools import chain
import json
import logging import logging
from pathlib import Path
import socket import socket
from typing import Any, cast from typing import Any, cast
import wave import wave
@@ -19,9 +22,12 @@ from aioesphomeapi import (
VoiceAssistantAudioSettings, VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag, VoiceAssistantCommandFlag,
VoiceAssistantEventType, VoiceAssistantEventType,
VoiceAssistantExternalWakeWord,
VoiceAssistantFeature, VoiceAssistantFeature,
VoiceAssistantTimerEventType, VoiceAssistantTimerEventType,
) )
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import assist_satellite, tts from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
@@ -29,6 +35,7 @@ from homeassistant.components.assist_pipeline import (
PipelineEventType, PipelineEventType,
PipelineStage, PipelineStage,
) )
from homeassistant.components.http import StaticPathConfig
from homeassistant.components.intent import ( from homeassistant.components.intent import (
TimerEventType, TimerEventType,
TimerInfo, TimerInfo,
@@ -39,8 +46,11 @@ from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.helpers.network import get_url
from homeassistant.helpers.singleton import singleton
from homeassistant.util.hass_dict import HassKey
from .const import DOMAIN from .const import DOMAIN, WAKE_WORDS_API_PATH, WAKE_WORDS_DIR_NAME
from .entity import EsphomeAssistEntity, convert_api_error_ha_error from .entity import EsphomeAssistEntity, convert_api_error_ha_error
from .entry_data import ESPHomeConfigEntry from .entry_data import ESPHomeConfigEntry
from .enum_mapper import EsphomeEnumMapper from .enum_mapper import EsphomeEnumMapper
@@ -84,6 +94,16 @@ _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventTy
_ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60 # 5 minutes _ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60 # 5 minutes
_CONFIG_TIMEOUT_SEC = 5 _CONFIG_TIMEOUT_SEC = 5
_WAKE_WORD_CONFIG_SCHEMA = vol.Schema(
{
vol.Required("type"): str,
vol.Required("wake_word"): str,
},
extra=vol.ALLOW_EXTRA,
)
_DATA_WAKE_WORDS: HassKey[dict[str, VoiceAssistantExternalWakeWord]] = HassKey(
"wake_word_cache"
)
async def async_setup_entry( async def async_setup_entry(
@@ -182,9 +202,14 @@ class EsphomeAssistSatellite(
async def _update_satellite_config(self) -> None: async def _update_satellite_config(self) -> None:
"""Get the latest satellite configuration from the device.""" """Get the latest satellite configuration from the device."""
wake_words = await async_get_custom_wake_words(self.hass)
if wake_words:
_LOGGER.debug("Found custom wake words: %s", sorted(wake_words.keys()))
try: try:
config = await self.cli.get_voice_assistant_configuration( config = await self.cli.get_voice_assistant_configuration(
_CONFIG_TIMEOUT_SEC _CONFIG_TIMEOUT_SEC,
external_wake_words=list(wake_words.values()),
) )
except TimeoutError: except TimeoutError:
# Placeholder config will be used # Placeholder config will be used
@@ -784,3 +809,78 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
return return
self.transport.sendto(data, self.remote_addr) self.transport.sendto(data, self.remote_addr)
async def async_get_custom_wake_words(
hass: HomeAssistant,
) -> dict[str, VoiceAssistantExternalWakeWord]:
"""Get available custom wake words."""
return await hass.async_add_executor_job(_get_custom_wake_words, hass)
@singleton(_DATA_WAKE_WORDS)
def _get_custom_wake_words(
hass: HomeAssistant,
) -> dict[str, VoiceAssistantExternalWakeWord]:
"""Get available custom wake words (singleton)."""
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
wake_words: dict[str, VoiceAssistantExternalWakeWord] = {}
# Look for config/model files
for config_path in wake_words_dir.glob("*.json"):
wake_word_id = config_path.stem
model_path = config_path.with_suffix(".tflite")
if not model_path.exists():
# Missing model file
continue
with open(config_path, encoding="utf-8") as config_file:
config_dict = json.load(config_file)
try:
config = _WAKE_WORD_CONFIG_SCHEMA(config_dict)
except vol.Invalid as err:
# Invalid config
_LOGGER.debug(
"Invalid wake word config: path=%s, error=%s",
config_path,
humanize_error(config_dict, err),
)
continue
with open(model_path, "rb") as model_file:
model_hash = hashlib.sha256(model_file.read()).hexdigest()
model_size = model_path.stat().st_size
config_rel_path = config_path.relative_to(wake_words_dir)
# Only intended for the internal network
base_url = get_url(hass, prefer_external=False, allow_cloud=False)
wake_words[wake_word_id] = VoiceAssistantExternalWakeWord.from_dict(
{
"id": wake_word_id,
"wake_word": config["wake_word"],
"trained_languages": config_dict.get("trained_languages", []),
"model_type": config["type"],
"model_size": model_size,
"model_hash": model_hash,
"url": f"{base_url}{WAKE_WORDS_API_PATH}/{config_rel_path}",
}
)
return wake_words
async def async_setup(hass: HomeAssistant) -> None:
"""Set up the satellite."""
wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
# Satellites will pull model files over HTTP
await hass.http.async_register_static_paths(
[
StaticPathConfig(
url_path=WAKE_WORDS_API_PATH,
path=str(wake_words_dir),
)
]
)

View File

@@ -27,3 +27,6 @@ STABLE_BLE_URL_VERSION = f"{STABLE_BLE_VERSION.major}.{STABLE_BLE_VERSION.minor}
DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html" DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html"
NO_WAKE_WORD: Final[str] = "no_wake_word" NO_WAKE_WORD: Final[str] = "no_wake_word"
WAKE_WORDS_DIR_NAME = "custom_wake_words"
WAKE_WORDS_API_PATH = "/api/esphome/wake_words"

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
from dataclasses import replace from dataclasses import replace
from http import HTTPStatus
import io import io
import socket import socket
from unittest.mock import ANY, AsyncMock, Mock, patch from unittest.mock import ANY, AsyncMock, Mock, patch
@@ -55,6 +56,7 @@ from .common import get_satellite_entity
from .conftest import MockESPHomeDeviceType from .conftest import MockESPHomeDeviceType
from tests.components.tts.common import MockResultStream from tests.components.tts.common import MockResultStream
from tests.typing import ClientSessionGenerator
@pytest.fixture @pytest.fixture
@@ -2087,3 +2089,80 @@ async def test_secondary_pipeline(
# Primary pipeline should be restored after # Primary pipeline should be restored after
assert (await get_pipeline(None)) == "Primary Pipeline" assert (await get_pipeline(None)) == "Primary Pipeline"
async def test_custom_wake_words(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: MockESPHomeDeviceType,
hass_client: ClientSessionGenerator,
) -> None:
"""Test exposing custom wake word models.
Expects 2 models in testing_config/custom_wake_words:
- hey_home_assistant
- choo_choo_homie
"""
http_client = await hass_client()
expected_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
],
active_wake_words=["1234"],
max_active_wake_words=1,
)
gvac = mock_client.get_voice_assistant_configuration
gvac.return_value = expected_config
mock_device = await mock_esphome_device(
mock_client=mock_client,
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
# Models should be present in testing_config/custom_wake_words
gvac.assert_called_once()
external_wake_words = gvac.call_args_list[0].kwargs["external_wake_words"]
assert len(external_wake_words) == 2
assert {external_wake_words[0].id, external_wake_words[1].id} == {
"hey_home_assistant",
"choo_choo_homie",
}
# Verify details
for eww in external_wake_words:
if eww.id == "hey_home_assistant":
assert eww.wake_word == "Hey Home Assistant"
else:
assert eww.wake_word == "Choo Choo Homie"
assert eww.model_type == "micro"
assert eww.model_size == 4 # tflite files contain "test"
assert (
eww.model_hash
== "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
)
assert eww.trained_languages == ["en"]
# GET config
config_url = eww.url[eww.url.find("/api") :]
req = await http_client.get(config_url)
assert req.status == HTTPStatus.OK
config_dict = await req.json()
# GET model
model = config_dict["model"]
model_url = config_url[: config_url.rfind("/")] + f"/{model}"
req = await http_client.get(model_url)
assert req.status == HTTPStatus.OK
# Check non-existent wake word
req = await http_client.get("/api/esphome/wake_words/wrong_wake_word.json")
assert req.status == HTTPStatus.NOT_FOUND

View File

@@ -0,0 +1 @@
{}

View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,16 @@
{
"type": "micro",
"wake_word": "Choo Choo Homie",
"author": "Michael Hansen",
"website": "https://www.home-assistant.io",
"model": "choo_choo_homie.tflite",
"trained_languages": ["en"],
"version": 2,
"micro": {
"probability_cutoff": 0.97,
"feature_step_size": 10,
"sliding_window_size": 5,
"tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0"
}
}

View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,16 @@
{
"type": "micro",
"wake_word": "Hey Home Assistant",
"author": "Michael Hansen",
"website": "https://www.home-assistant.io",
"model": "hey_home_assistant.tflite",
"trained_languages": ["en"],
"version": 2,
"micro": {
"probability_cutoff": 0.97,
"feature_step_size": 10,
"sliding_window_size": 5,
"tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0"
}
}

View File

@@ -0,0 +1 @@
test

View File

@@ -0,0 +1,16 @@
{
"type": "micro",
"wake_word": "Missing Model",
"author": "Michael Hansen",
"website": "https://www.home-assistant.io",
"model": "missing_model.tflite",
"trained_languages": ["en"],
"version": 2,
"micro": {
"probability_cutoff": 0.97,
"feature_step_size": 10,
"sliding_window_size": 5,
"tensor_arena_size": 30000,
"minimum_esphome_version": "2024.7.0"
}
}