diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py
index cb1a3d10c97..3c01d3bb955 100644
--- a/homeassistant/components/esphome/__init__.py
+++ b/homeassistant/components/esphome/__init__.py
@@ -17,7 +17,7 @@ from homeassistant.helpers import config_validation as cv
 from homeassistant.helpers.issue_registry import async_delete_issue
 from homeassistant.helpers.typing import ConfigType
 
-from . import dashboard, ffmpeg_proxy
+from . import assist_satellite, dashboard, ffmpeg_proxy
 from .const import CONF_BLUETOOTH_MAC_ADDRESS, CONF_NOISE_PSK, DOMAIN
 from .domain_data import DomainData
 from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
@@ -31,6 +31,7 @@ CLIENT_INFO = f"Home Assistant {ha_version}"
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
     """Set up the esphome component."""
     ffmpeg_proxy.async_setup(hass)
+    await assist_satellite.async_setup(hass)
     await dashboard.async_setup(hass)
     return True
 
diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py
index aa565fa6107..9b3d954d221 100644
--- a/homeassistant/components/esphome/assist_satellite.py
+++ b/homeassistant/components/esphome/assist_satellite.py
@@ -5,9 +5,12 @@ from __future__ import annotations
 import asyncio
 from collections.abc import AsyncIterable
 from functools import partial
+import hashlib
 import io
 from itertools import chain
+import json
 import logging
+from pathlib import Path
 import socket
 from typing import Any, cast
 import wave
@@ -19,9 +22,12 @@ from aioesphomeapi import (
     VoiceAssistantAudioSettings,
     VoiceAssistantCommandFlag,
     VoiceAssistantEventType,
+    VoiceAssistantExternalWakeWord,
     VoiceAssistantFeature,
     VoiceAssistantTimerEventType,
 )
+import voluptuous as vol
+from voluptuous.humanize import humanize_error
 
 from homeassistant.components import assist_satellite, tts
 from homeassistant.components.assist_pipeline import (
@@ -29,6 +35,7 @@ from homeassistant.components.assist_pipeline import (
     PipelineEventType,
     PipelineStage,
 )
+from homeassistant.components.http import StaticPathConfig
 from homeassistant.components.intent import (
     TimerEventType,
     TimerInfo,
@@ -39,8 +46,11 @@ from homeassistant.const import Platform
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.helpers import entity_registry as er
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+from homeassistant.helpers.network import get_url
+from homeassistant.helpers.singleton import singleton
+from homeassistant.util.hass_dict import HassKey
 
-from .const import DOMAIN
+from .const import DOMAIN, WAKE_WORDS_API_PATH, WAKE_WORDS_DIR_NAME
 from .entity import EsphomeAssistEntity, convert_api_error_ha_error
 from .entry_data import ESPHomeConfigEntry
 from .enum_mapper import EsphomeEnumMapper
@@ -84,6 +94,16 @@ _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventTy
 
 _ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60  # 5 minutes
 _CONFIG_TIMEOUT_SEC = 5
+_WAKE_WORD_CONFIG_SCHEMA = vol.Schema(
+    {
+        vol.Required("type"): str,
+        vol.Required("wake_word"): str,
+    },
+    extra=vol.ALLOW_EXTRA,
+)
+_DATA_WAKE_WORDS: HassKey[dict[str, VoiceAssistantExternalWakeWord]] = HassKey(
+    "wake_word_cache"
+)
 
 
 async def async_setup_entry(
@@ -182,9 +202,14 @@ class EsphomeAssistSatellite(
 
     async def _update_satellite_config(self) -> None:
         """Get the latest satellite configuration from the device."""
+        wake_words = await async_get_custom_wake_words(self.hass)
+        if wake_words:
+            _LOGGER.debug("Found custom wake words: %s", sorted(wake_words.keys()))
+
         try:
             config = await self.cli.get_voice_assistant_configuration(
-                _CONFIG_TIMEOUT_SEC
+                _CONFIG_TIMEOUT_SEC,
+                external_wake_words=list(wake_words.values()),
             )
         except TimeoutError:
             # Placeholder config will be used
@@ -784,3 +809,93 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
             return
 
         self.transport.sendto(data, self.remote_addr)
+
+
+async def async_get_custom_wake_words(
+    hass: HomeAssistant,
+) -> dict[str, VoiceAssistantExternalWakeWord]:
+    """Get available custom wake words."""
+    return await hass.async_add_executor_job(_get_custom_wake_words, hass)
+
+
+@singleton(_DATA_WAKE_WORDS)
+def _get_custom_wake_words(
+    hass: HomeAssistant,
+) -> dict[str, VoiceAssistantExternalWakeWord]:
+    """Get available custom wake words (singleton).
+
+    Each wake word is a "<id>.json" config with a "<id>.tflite" model next to
+    it. Configs that cannot be read, parsed or validated are skipped so that
+    one bad file does not hide the remaining wake words.
+    """
+    wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
+    wake_words: dict[str, VoiceAssistantExternalWakeWord] = {}
+
+    # Look for config/model files
+    for config_path in wake_words_dir.glob("*.json"):
+        wake_word_id = config_path.stem
+        model_path = config_path.with_suffix(".tflite")
+        if not model_path.exists():
+            # Missing model file
+            continue
+
+        try:
+            with open(config_path, encoding="utf-8") as config_file:
+                config_dict = json.load(config_file)
+        except (OSError, ValueError) as err:
+            # Unreadable file or malformed JSON
+            _LOGGER.debug(
+                "Failed to load wake word config: path=%s, error=%s",
+                config_path,
+                err,
+            )
+            continue
+
+        try:
+            config = _WAKE_WORD_CONFIG_SCHEMA(config_dict)
+        except vol.Invalid as err:
+            # Invalid config
+            _LOGGER.debug(
+                "Invalid wake word config: path=%s, error=%s",
+                config_path,
+                humanize_error(config_dict, err),
+            )
+            continue
+
+        # Hash and size come from a single read so they always agree
+        model_bytes = model_path.read_bytes()
+        model_hash = hashlib.sha256(model_bytes).hexdigest()
+        model_size = len(model_bytes)
+        config_rel_path = config_path.relative_to(wake_words_dir)
+
+        # Only intended for the internal network
+        base_url = get_url(hass, prefer_external=False, allow_cloud=False)
+
+        wake_words[wake_word_id] = VoiceAssistantExternalWakeWord.from_dict(
+            {
+                "id": wake_word_id,
+                "wake_word": config["wake_word"],
+                "trained_languages": config_dict.get("trained_languages", []),
+                "model_type": config["type"],
+                "model_size": model_size,
+                "model_hash": model_hash,
+                "url": f"{base_url}{WAKE_WORDS_API_PATH}/{config_rel_path}",
+            }
+        )
+
+    return wake_words
+
+
+async def async_setup(hass: HomeAssistant) -> None:
+    """Set up the satellite."""
+    wake_words_dir = Path(hass.config.path(WAKE_WORDS_DIR_NAME))
+
+    # Satellites will pull model files over HTTP
+    await hass.http.async_register_static_paths(
+        [
+            StaticPathConfig(
+                url_path=WAKE_WORDS_API_PATH,
+                path=str(wake_words_dir),
+            )
+        ]
+    )
diff --git a/homeassistant/components/esphome/const.py b/homeassistant/components/esphome/const.py
index 86688ebb8a6..14595356035 100644
--- a/homeassistant/components/esphome/const.py
+++ b/homeassistant/components/esphome/const.py
@@ -27,3 +27,6 @@ STABLE_BLE_URL_VERSION = f"{STABLE_BLE_VERSION.major}.{STABLE_BLE_VERSION.minor}
 DEFAULT_URL = f"https://esphome.io/changelog/{STABLE_BLE_URL_VERSION}.html"
 
 NO_WAKE_WORD: Final[str] = "no_wake_word"
+
+WAKE_WORDS_DIR_NAME = "custom_wake_words"
+WAKE_WORDS_API_PATH = "/api/esphome/wake_words"
diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py
index fe5ac70d687..149befc5b9d 100644
--- a/tests/components/esphome/test_assist_satellite.py
+++ b/tests/components/esphome/test_assist_satellite.py
@@ -2,6 +2,7 @@
 
 import asyncio
 from dataclasses import replace
+from http import HTTPStatus
 import io
 import socket
 from unittest.mock import ANY, AsyncMock, Mock, patch
@@ -55,6 +56,7 @@ from .common import get_satellite_entity
 from .conftest import MockESPHomeDeviceType
 
 from tests.components.tts.common import MockResultStream
+from tests.typing import ClientSessionGenerator
 
 
 @pytest.fixture
@@ -2087,3 +2089,80 @@ async def test_secondary_pipeline(
 
     # Primary pipeline should be restored after
     assert (await get_pipeline(None)) == "Primary Pipeline"
+
+
+async def test_custom_wake_words(
+    hass: HomeAssistant,
+    mock_client: APIClient,
+    mock_esphome_device: MockESPHomeDeviceType,
+    hass_client: ClientSessionGenerator,
+) -> None:
+    """Test exposing custom wake word models.
+
+    Expects 2 models in testing_config/custom_wake_words:
+    - hey_home_assistant
+    - choo_choo_homie
+    """
+    http_client = await hass_client()
+    expected_config = AssistSatelliteConfiguration(
+        available_wake_words=[
+            AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
+        ],
+        active_wake_words=["1234"],
+        max_active_wake_words=1,
+    )
+    gvac = mock_client.get_voice_assistant_configuration
+    gvac.return_value = expected_config
+
+    mock_device = await mock_esphome_device(
+        mock_client=mock_client,
+        device_info={
+            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
+            | VoiceAssistantFeature.ANNOUNCE
+        },
+    )
+    await hass.async_block_till_done()
+
+    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
+    assert satellite is not None
+
+    # Models should be present in testing_config/custom_wake_words
+    gvac.assert_called_once()
+    external_wake_words = gvac.call_args_list[0].kwargs["external_wake_words"]
+    assert len(external_wake_words) == 2
+
+    assert {external_wake_words[0].id, external_wake_words[1].id} == {
+        "hey_home_assistant",
+        "choo_choo_homie",
+    }
+
+    # Verify details
+    for eww in external_wake_words:
+        if eww.id == "hey_home_assistant":
+            assert eww.wake_word == "Hey Home Assistant"
+        else:
+            assert eww.wake_word == "Choo Choo Homie"
+
+        assert eww.model_type == "micro"
+        assert eww.model_size == 4  # tflite files contain "test"
+        assert (
+            eww.model_hash
+            == "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
+        )
+        assert eww.trained_languages == ["en"]
+
+        # GET config
+        config_url = eww.url[eww.url.find("/api") :]
+        req = await http_client.get(config_url)
+        assert req.status == HTTPStatus.OK
+        config_dict = await req.json()
+
+        # GET model
+        model = config_dict["model"]
+        model_url = config_url[: config_url.rfind("/")] + f"/{model}"
+        req = await http_client.get(model_url)
+        assert req.status == HTTPStatus.OK
+
+    # Check non-existent wake word
+    req = await http_client.get("/api/esphome/wake_words/wrong_wake_word.json")
+    assert req.status == HTTPStatus.NOT_FOUND
diff --git a/tests/testing_config/custom_wake_words/bad_config.json b/tests/testing_config/custom_wake_words/bad_config.json
new file mode 100644
index 00000000000..0967ef424bc
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/bad_config.json
@@ -0,0 +1 @@
+{}
diff --git a/tests/testing_config/custom_wake_words/bad_config.tflite b/tests/testing_config/custom_wake_words/bad_config.tflite
new file mode 100644
index 00000000000..30d74d25844
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/bad_config.tflite
@@ -0,0 +1 @@
+test
\ No newline at end of file
diff --git a/tests/testing_config/custom_wake_words/choo_choo_homie.json b/tests/testing_config/custom_wake_words/choo_choo_homie.json
new file mode 100644
index 00000000000..dbebb947a97
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/choo_choo_homie.json
@@ -0,0 +1,16 @@
+{
+  "type": "micro",
+  "wake_word": "Choo Choo Homie",
+  "author": "Michael Hansen",
+  "website": "https://www.home-assistant.io",
+  "model": "choo_choo_homie.tflite",
+  "trained_languages": ["en"],
+  "version": 2,
+  "micro": {
+    "probability_cutoff": 0.97,
+    "feature_step_size": 10,
+    "sliding_window_size": 5,
+    "tensor_arena_size": 30000,
+    "minimum_esphome_version": "2024.7.0"
+  }
+}
diff --git a/tests/testing_config/custom_wake_words/choo_choo_homie.tflite b/tests/testing_config/custom_wake_words/choo_choo_homie.tflite
new file mode 100644
index 00000000000..30d74d25844
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/choo_choo_homie.tflite
@@ -0,0 +1 @@
+test
\ No newline at end of file
diff --git a/tests/testing_config/custom_wake_words/hey_home_assistant.json b/tests/testing_config/custom_wake_words/hey_home_assistant.json
new file mode 100644
index 00000000000..b49f8afccd0
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/hey_home_assistant.json
@@ -0,0 +1,16 @@
+{
+  "type": "micro",
+  "wake_word": "Hey Home Assistant",
+  "author": "Michael Hansen",
+  "website": "https://www.home-assistant.io",
+  "model": "hey_home_assistant.tflite",
+  "trained_languages": ["en"],
+  "version": 2,
+  "micro": {
+    "probability_cutoff": 0.97,
+    "feature_step_size": 10,
+    "sliding_window_size": 5,
+    "tensor_arena_size": 30000,
+    "minimum_esphome_version": "2024.7.0"
+  }
+}
diff --git a/tests/testing_config/custom_wake_words/hey_home_assistant.tflite b/tests/testing_config/custom_wake_words/hey_home_assistant.tflite
new file mode 100644
index 00000000000..30d74d25844
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/hey_home_assistant.tflite
@@ -0,0 +1 @@
+test
\ No newline at end of file
diff --git a/tests/testing_config/custom_wake_words/missing_model.json b/tests/testing_config/custom_wake_words/missing_model.json
new file mode 100644
index 00000000000..07146275533
--- /dev/null
+++ b/tests/testing_config/custom_wake_words/missing_model.json
@@ -0,0 +1,16 @@
+{
+  "type": "micro",
+  "wake_word": "Missing Model",
+  "author": "Michael Hansen",
+  "website": "https://www.home-assistant.io",
+  "model": "missing_model.tflite",
+  "trained_languages": ["en"],
+  "version": 2,
+  "micro": {
+    "probability_cutoff": 0.97,
+    "feature_step_size": 10,
+    "sliding_window_size": 5,
+    "tensor_arena_size": 30000,
+    "minimum_esphome_version": "2024.7.0"
+  }
+}