1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-21 18:38:17 +00:00
Files
core/tests/components/elevenlabs/test_tts.py
2025-10-18 23:50:01 +02:00

672 lines
22 KiB
Python

"""Tests for the ElevenLabs TTS entity."""
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterator, Iterator
from http import HTTPStatus
from pathlib import Path
from typing import Any, Self
from unittest.mock import AsyncMock, MagicMock
from elevenlabs.core import ApiError
from elevenlabs.types import VoiceSettings
import pytest
from homeassistant.components import tts
from homeassistant.components.elevenlabs.const import (
ATTR_MODEL,
CONF_SIMILARITY,
CONF_STABILITY,
CONF_STYLE,
CONF_USE_SPEAKER_BOOST,
DEFAULT_SIMILARITY,
DEFAULT_STABILITY,
DEFAULT_STYLE,
DEFAULT_USE_SPEAKER_BOOST,
)
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
)
from homeassistant.components.tts import TTSAudioRequest
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core_config import async_process_ha_core_config
from tests.common import async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
class _FakeResponse:
    """Minimal response stub exposing only a ``headers`` mapping."""

    def __init__(self, headers: dict[str, str]) -> None:
        """Remember the supplied headers for later inspection."""
        self.headers = headers
class _AsyncByteStream:
    """Async iterator that yields bytes and exposes response headers like ElevenLabs' stream.

    ``request_id`` is accepted for signature parity with the real client but is
    not consumed by the iterator itself.
    """

    def __init__(self, chunks: list[bytes], request_id: str | None = None) -> None:
        self._pending = deque(chunks)

    def __aiter__(self) -> AsyncIterator[bytes]:
        return self

    async def __anext__(self) -> bytes:
        if not self._pending:
            raise StopAsyncIteration
        chunk = self._pending.popleft()
        # Let the event loop breathe, mirroring a real async network iterator.
        await asyncio.sleep(0)
        return chunk
class _AsyncStreamResponse:
    """Async context manager that mimics ElevenLabs raw stream responses."""

    def __init__(self, chunks: list[bytes], request_id: str | None = None) -> None:
        # Real responses surface the request id via a response header;
        # a falsy/missing id yields an empty header mapping.
        self.headers = {"request-id": request_id} if request_id else {}
        self.data = _AsyncByteStream(chunks)

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        return None
@pytest.fixture
def capture_stream_calls():
    """Patch AsyncElevenLabs.text_to_speech.with_raw_response.stream and capture call kwargs.

    Returns a 3-tuple:
        calls: list[dict] — kwargs passed into each stream() invocation.
        set_next_return(*, chunks, request_id, error=None) — configure what the
            NEXT stream() call yields, or the exception it raises instead.
        patch_stream(tts_entity) — install the capturing mock on the given
            entity's ElevenLabs client.
    """
    calls: list[dict] = []
    # Defaults used until a test overrides them via set_next_return().
    # "error" is initialized explicitly so _mock_stream never relies on a
    # missing key. (The previous version also requested the unused
    # ``monkeypatch`` fixture; patching is done by direct assignment below.)
    state: dict[str, Any] = {"chunks": [b"X"], "request_id": "rid-1", "error": None}

    def set_next_return(
        *, chunks: list[bytes], request_id: str | None, error: Exception | None = None
    ) -> None:
        """Set the chunks/request id (or error) for the next stream() call."""
        state["chunks"] = chunks
        state["request_id"] = request_id
        state["error"] = error

    def patch_stream(tts_entity) -> None:
        """Replace the entity client's raw stream() with the capturing mock."""

        def _mock_stream(**kwargs):
            calls.append(kwargs)
            if state["error"] is not None:
                raise state["error"]
            return _AsyncStreamResponse(
                chunks=list(state["chunks"]),
                request_id=state["request_id"],
            )

        tts_entity._client.text_to_speech.with_raw_response.stream = _mock_stream

    return calls, set_next_return, patch_stream
@pytest.fixture
def stream_sentence_helpers():
    """Return a factory producing helpers for queue-driven sentence streaming."""

    def factory(sentence_iter: Iterator[tuple], queue: asyncio.Queue[str | None]):
        async def get_next_part() -> tuple[Any, ...]:
            # Pull the next (parts, chunk, request_id) triple; on exhaustion,
            # push the terminating sentinel and return an all-None triple.
            try:
                return next(sentence_iter)
            except StopIteration:
                await queue.put(None)
                return None, None, None

        async def message_gen() -> AsyncIterator[str]:
            # Yield queued parts until the None sentinel arrives.
            while (part := await queue.get()) is not None:
                yield part

        return get_next_part, message_gen

    return factory
@pytest.fixture(autouse=True)
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
    """Apply the mutagen tag-writing mock to every test in this module."""
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
    """Use an empty temporary directory as the TTS cache for every test."""
@pytest.fixture
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
    """Capture media_player.play_media service calls for inspection."""
    return async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@pytest.fixture(autouse=True)
async def setup_internal_url(hass: HomeAssistant) -> None:
    """Configure the internal URL used when building media URLs."""
    config = {"internal_url": "http://example.local:8123"}
    await async_process_ha_core_config(hass, config)
@pytest.mark.parametrize(
    "config_data",
    [
        {},
        {tts.CONF_LANG: "de"},
        {tts.CONF_LANG: "en"},
        {tts.CONF_LANG: "ja"},
        {tts.CONF_LANG: "es"},
    ],
)
@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {ATTR_MODEL: "model2"},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2", ATTR_MODEL: "model2"},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """Speak service forwards message, voice, model and language to stream()."""
    stream_calls, _, patch_stream = capture_stream_calls
    entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    patch_stream(entity)
    # The entity is configured with the default voice settings.
    assert entity._voice_settings == VoiceSettings(
        stability=DEFAULT_STABILITY,
        similarity_boost=DEFAULT_SIMILARITY,
        style=DEFAULT_STYLE,
        use_speaker_boost=DEFAULT_USE_SPEAKER_BOOST,
    )
    await hass.services.async_call(tts.DOMAIN, tts_service, service_data, blocking=True)
    media_id = calls[0].data[ATTR_MEDIA_CONTENT_ID]
    assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
    assert len(stream_calls) == 1
    # Voice/model fall back to the entity defaults when not passed as options.
    options = service_data[tts.ATTR_OPTIONS]
    expected = {
        "text": "There is a person at the front door.",
        "voice_id": options.get(tts.ATTR_VOICE, "voice1"),
        "model_id": options.get(ATTR_MODEL, "model1"),
        "voice_settings": entity._voice_settings,
        "output_format": "mp3_44100_128",
        "language_code": service_data.get(tts.ATTR_LANGUAGE, entity.default_language),
    }
    for key, value in expected.items():
        assert stream_calls[0][key] == value
@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_LANGUAGE: "de",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
            },
        ),
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_LANGUAGE: "es",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak_lang_config(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """Speak service passes an explicitly requested language through to stream()."""
    stream_calls, _, patch_stream = capture_stream_calls
    entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    patch_stream(entity)
    await hass.services.async_call(tts.DOMAIN, tts_service, service_data, blocking=True)
    media_id = calls[0].data[ATTR_MEDIA_CONTENT_ID]
    assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
    assert len(stream_calls) == 1
    expected = {
        "text": "There is a person at the front door.",
        "voice_id": "voice1",
        "model_id": "model1",
        "voice_settings": entity._voice_settings,
        "output_format": "mp3_44100_128",
        "language_code": service_data.get(tts.ATTR_LANGUAGE, entity.default_language),
    }
    for key, value in expected.items():
        assert stream_calls[0][key] == value
@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak_error(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """An ApiError from stream() surfaces as an internal-server-error media fetch."""
    stream_calls, set_next_return, patch_stream = capture_stream_calls
    entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    patch_stream(entity)
    # Make the next stream() call raise instead of yielding audio.
    set_next_return(chunks=[], request_id=None, error=ApiError())
    await hass.services.async_call(tts.DOMAIN, tts_service, service_data, blocking=True)
    media_id = calls[0].data[ATTR_MEDIA_CONTENT_ID]
    assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.INTERNAL_SERVER_ERROR
    # stream() was still attempted exactly once with the expected kwargs.
    assert len(stream_calls) == 1
    expected = {
        "text": "There is a person at the front door.",
        "voice_id": "voice1",
        "model_id": "model1",
        "voice_settings": entity._voice_settings,
        "output_format": "mp3_44100_128",
        "language_code": service_data.get(tts.ATTR_LANGUAGE, entity.default_language),
    }
    for key, value in expected.items():
        assert stream_calls[0][key] == value
@pytest.mark.parametrize(
    "config_data",
    [
        {},
        {tts.CONF_LANG: "de"},
        {tts.CONF_LANG: "en"},
        {tts.CONF_LANG: "ja"},
        {tts.CONF_LANG: "es"},
    ],
)
@pytest.mark.parametrize(
    ("config_options", "tts_service", "service_data"),
    [
        (
            {
                CONF_SIMILARITY: DEFAULT_SIMILARITY / 2,
                CONF_STABILITY: DEFAULT_STABILITY,
                CONF_STYLE: DEFAULT_STYLE,
                CONF_USE_SPEAKER_BOOST: DEFAULT_USE_SPEAKER_BOOST,
            },
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
            },
        ),
    ],
)
async def test_tts_service_speak_voice_settings(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
    mock_similarity: float,
) -> None:
    """Configured voice-setting overrides are applied to the entity and to stream()."""
    stream_calls, _, patch_stream = capture_stream_calls
    entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    patch_stream(entity)
    # similarity_boost reflects the halved config value; the rest stay default.
    assert entity._voice_settings == VoiceSettings(
        stability=DEFAULT_STABILITY,
        similarity_boost=DEFAULT_SIMILARITY / 2,
        style=DEFAULT_STYLE,
        use_speaker_boost=DEFAULT_USE_SPEAKER_BOOST,
    )
    await hass.services.async_call(tts.DOMAIN, tts_service, service_data, blocking=True)
    media_id = calls[0].data[ATTR_MEDIA_CONTENT_ID]
    assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
    assert len(stream_calls) == 1
    expected = {
        "text": "There is a person at the front door.",
        "voice_id": "voice2",
        "model_id": "model1",
        "voice_settings": entity._voice_settings,
        "output_format": "mp3_44100_128",
        "language_code": service_data.get(tts.ATTR_LANGUAGE, entity.default_language),
    }
    for key, value in expected.items():
        assert stream_calls[0][key] == value
@pytest.mark.parametrize(
    ("setup", "tts_service", "service_data"),
    [
        (
            "mock_config_entry_setup",
            "speak",
            {
                ATTR_ENTITY_ID: "tts.elevenlabs_text_to_speech",
                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
                tts.ATTR_MESSAGE: "There is a person at the front door.",
                tts.ATTR_OPTIONS: {},
            },
        ),
    ],
    indirect=["setup"],
)
async def test_tts_service_speak_without_options(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    calls: list[ServiceCall],
    tts_service: str,
    service_data: dict[str, Any],
) -> None:
    """With no options, stream() receives the default voice, model and settings."""
    stream_calls, _, patch_stream = capture_stream_calls
    entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
    patch_stream(entity)
    await hass.services.async_call(tts.DOMAIN, tts_service, service_data, blocking=True)
    media_id = calls[0].data[ATTR_MEDIA_CONTENT_ID]
    assert await retrieve_media(hass, hass_client, media_id) == HTTPStatus.OK
    assert len(stream_calls) == 1
    expected = {
        "text": "There is a person at the front door.",
        "voice_id": "voice1",
        "model_id": "model1",
        # Literal defaults, asserted independently of the entity's settings.
        "voice_settings": VoiceSettings(
            stability=0.5,
            similarity_boost=0.75,
            style=0.0,
            use_speaker_boost=True,
        ),
        "output_format": "mp3_44100_128",
        "language_code": service_data.get(tts.ATTR_LANGUAGE, entity.default_language),
    }
    for key, value in expected.items():
        assert stream_calls[0][key] == value
@pytest.mark.parametrize(
    ("setup", "model_id"),
    [
        ("mock_config_entry_setup", "eleven_multilingual_v2"),
    ],
    indirect=["setup"],
)
@pytest.mark.parametrize(
    ("message", "chunks", "request_ids"),
    [
        (
            # Four sentence batches; sentence boundaries fall mid-part.
            [
                ["One. ", "Two! ", "Three"],
                ["! ", "Four"],
                ["? ", "Five"],
                ["! ", "Six!"],
            ],
            [b"\x05\x06", b"\x07\x08", b"\x09\x0a", b"\x0b\x0c"],
            ["rid-1", "rid-2", "rid-3", "rid-4"],
        ),
    ],
)
async def test_stream_tts_with_request_ids(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    stream_sentence_helpers,
    model_id: str,
    message: list[list[str]],
    chunks: list[bytes],
    request_ids: list[str],
) -> None:
    """Test streaming TTS with request-id stitching.

    Message parts are released through a queue one batch at a time, and each
    stream() call is checked for the previously collected request ids
    (a rolling window of at most three).
    """
    calls, set_next_return, patch_stream = capture_stream_calls
    # Access the TTS entity as in your existing tests; adjust if you use a fixture instead
    tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech")
    patch_stream(tts_entity)
    # Use a queue to control when each part is yielded
    queue = asyncio.Queue()
    prev_request_ids: deque[str] = deque(maxlen=3)  # keep last 3 request IDs
    sentence_iter = iter(zip(message, chunks, request_ids, strict=False))
    get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue)
    options = {tts.ATTR_VOICE: "voice1", "model": model_id}
    req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options)
    resp = await tts_entity.async_stream_tts_audio(req)
    assert resp.extension == "mp3"
    # Prime the queue with the first batch and stage the mock's first reply.
    item, chunk, request_id = await get_next_part()
    if item is not None:
        for part in item:
            await queue.put(part)
    else:
        await queue.put(None)
    set_next_return(chunks=[chunk], request_id=request_id)
    next_item, next_chunk, next_request_id = await get_next_part()
    # Consume bytes; after first chunk, switch next return to emulate second call
    async for b in resp.data_gen:
        assert b == chunk  # each sentence yields its first chunk immediately
        assert "previous_text" not in calls[-1]  # no previous_text for first sentence
        assert "next_text" not in calls[-1]  # no next_text for first
        # First call carries no ids; later calls carry the rolling window.
        assert calls[-1].get("previous_request_ids", []) == (
            [] if len(calls) == 1 else list(prev_request_ids)
        )
        if request_id:
            # ``or ""`` is redundant under the truthiness guard, kept as-is.
            prev_request_ids.append(request_id or "")
        # Advance to the batch pre-fetched above and stage the mock's reply.
        item, chunk, request_id = next_item, next_chunk, next_request_id
        if item is not None:
            for part in item:
                await queue.put(part)
            set_next_return(chunks=[chunk], request_id=request_id)
            next_item, next_chunk, next_request_id = await get_next_part()
            # NOTE(review): unreachable inside this not-None branch — confirm intent.
            if item is None:
                await queue.put(None)
        else:
            await queue.put(None)
    # One stream() invocation per sentence batch.
    assert len(calls) == len(message)
@pytest.mark.parametrize(
    ("message", "chunks", "request_ids"),
    [
        (
            [
                ["This is the first sentence. ", "This is "],
                ["the second sentence. "],
            ],
            [b"\x05\x06", b"\x07\x08"],
            ["rid-1", "rid-2"],
        ),
    ],
)
async def test_stream_tts_without_previous_info(
    setup: AsyncMock,
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    capture_stream_calls,
    stream_sentence_helpers,
    monkeypatch: pytest.MonkeyPatch,
    message: list[list[str]],
    chunks: list[bytes],
    request_ids: list[str],
) -> None:
    """Test streaming TTS without request-id stitching (eleven_v3).

    Marks the configured model as not supporting previous-info stitching and
    verifies that stream() calls never include ``previous_request_ids``.
    """
    calls, set_next_return, patch_stream = capture_stream_calls
    tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech")
    patch_stream(tts_entity)
    # Make "model1" one of the models without previous-info support.
    monkeypatch.setattr(
        "homeassistant.components.elevenlabs.tts.MODELS_PREVIOUS_INFO_NOT_SUPPORTED",
        ("model1",),
        raising=False,
    )
    # Queue controls when each message part is released to the entity.
    queue = asyncio.Queue()
    sentence_iter = iter(zip(message, chunks, request_ids, strict=False))
    get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue)
    options = {tts.ATTR_VOICE: "voice1", "model": "model1"}
    req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options)
    resp = await tts_entity.async_stream_tts_audio(req)
    assert resp.extension == "mp3"
    # Prime the queue with the first batch and stage the mock's first reply.
    item, chunk, request_id = await get_next_part()
    if item is not None:
        for part in item:
            await queue.put(part)
    else:
        await queue.put(None)
    set_next_return(chunks=[chunk], request_id=request_id)
    next_item, next_chunk, next_request_id = await get_next_part()
    # Consume bytes; after first chunk, switch next return to emulate second call
    async for b in resp.data_gen:
        assert b == chunk  # each sentence yields its first chunk immediately
        assert "previous_request_ids" not in calls[-1]  # no previous_request_ids
        # Advance to the batch pre-fetched above and stage the mock's reply.
        item, chunk, request_id = next_item, next_chunk, next_request_id
        if item is not None:
            for part in item:
                await queue.put(part)
            set_next_return(chunks=[chunk], request_id=request_id)
            next_item, next_chunk, next_request_id = await get_next_part()
            # NOTE(review): unreachable inside this not-None branch — confirm intent.
            if item is None:
                await queue.put(None)
        else:
            await queue.put(None)
    # We expect two stream() invocations (one per sentence batch)
    assert len(calls) == len(message)