mirror of
https://github.com/home-assistant/core.git
synced 2025-12-19 18:38:58 +00:00
Add streaming to Elevenlabs TTS (#154663)
This commit is contained in:
@@ -182,7 +182,6 @@ homeassistant.components.efergy.*
|
||||
homeassistant.components.eheimdigital.*
|
||||
homeassistant.components.electrasmart.*
|
||||
homeassistant.components.electric_kiwi.*
|
||||
homeassistant.components.elevenlabs.*
|
||||
homeassistant.components.elgato.*
|
||||
homeassistant.components.elkm1.*
|
||||
homeassistant.components.emulated_hue.*
|
||||
|
||||
@@ -21,6 +21,9 @@ DEFAULT_STT_MODEL = "scribe_v1"
|
||||
DEFAULT_STYLE = 0
|
||||
DEFAULT_USE_SPEAKER_BOOST = True
|
||||
|
||||
MAX_REQUEST_IDS = 3
|
||||
MODELS_PREVIOUS_INFO_NOT_SUPPORTED = ("eleven_v3",)
|
||||
|
||||
STT_LANGUAGES = [
|
||||
"af-ZA", # Afrikaans
|
||||
"am-ET", # Amharic
|
||||
|
||||
@@ -7,5 +7,5 @@
|
||||
"integration_type": "service",
|
||||
"iot_class": "cloud_polling",
|
||||
"loggers": ["elevenlabs"],
|
||||
"requirements": ["elevenlabs==2.3.0"]
|
||||
"requirements": ["elevenlabs==2.3.0", "sentence-stream==1.2.0"]
|
||||
}
|
||||
|
||||
@@ -85,4 +85,4 @@ rules:
|
||||
# Platinum
|
||||
async-dependency: done
|
||||
inject-websession: done
|
||||
strict-typing: done
|
||||
strict-typing: todo
|
||||
|
||||
@@ -2,17 +2,23 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from elevenlabs import AsyncElevenLabs
|
||||
from elevenlabs.core import ApiError
|
||||
from elevenlabs.types import Model, Voice as ElevenLabsVoice, VoiceSettings
|
||||
from sentence_stream import SentenceBoundaryDetector
|
||||
|
||||
from homeassistant.components.tts import (
|
||||
ATTR_VOICE,
|
||||
TextToSpeechEntity,
|
||||
TTSAudioRequest,
|
||||
TTSAudioResponse,
|
||||
TtsAudioType,
|
||||
Voice,
|
||||
)
|
||||
@@ -35,10 +41,12 @@ from .const import (
|
||||
DEFAULT_STYLE,
|
||||
DEFAULT_USE_SPEAKER_BOOST,
|
||||
DOMAIN,
|
||||
MAX_REQUEST_IDS,
|
||||
MODELS_PREVIOUS_INFO_NOT_SUPPORTED,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
PARALLEL_UPDATES = 0
|
||||
PARALLEL_UPDATES = 6
|
||||
|
||||
|
||||
def to_voice_settings(options: Mapping[str, Any]) -> VoiceSettings:
|
||||
@@ -122,7 +130,12 @@ class ElevenLabsTTSEntity(TextToSpeechEntity):
|
||||
self._attr_supported_languages = [
|
||||
lang.language_id for lang in self._model.languages or []
|
||||
]
|
||||
self._attr_default_language = self.supported_languages[0]
|
||||
# Use the first supported language as the default if available
|
||||
self._attr_default_language = (
|
||||
self._attr_supported_languages[0]
|
||||
if self._attr_supported_languages
|
||||
else "en"
|
||||
)
|
||||
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice]:
|
||||
"""Return a list of supported voices for a language."""
|
||||
@@ -151,3 +164,151 @@ class ElevenLabsTTSEntity(TextToSpeechEntity):
|
||||
)
|
||||
raise HomeAssistantError(exc) from exc
|
||||
return "mp3", bytes_combined
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
self, request: TTSAudioRequest
|
||||
) -> TTSAudioResponse:
|
||||
"""Generate speech from an incoming message."""
|
||||
_LOGGER.debug(
|
||||
"Getting TTS audio for language %s and options: %s",
|
||||
request.language,
|
||||
request.options,
|
||||
)
|
||||
return TTSAudioResponse("mp3", self._process_tts_stream(request))
|
||||
|
||||
async def _process_tts_stream(
|
||||
self, request: TTSAudioRequest
|
||||
) -> AsyncGenerator[bytes]:
|
||||
"""Generate speech from an incoming message."""
|
||||
text_stream = request.message_gen
|
||||
boundary_detector = SentenceBoundaryDetector()
|
||||
sentences: list[str] = []
|
||||
sentences_ready = asyncio.Event()
|
||||
sentences_complete = False
|
||||
|
||||
language_code: str | None = request.language
|
||||
voice_id = request.options.get(ATTR_VOICE, self._default_voice_id)
|
||||
model = request.options.get(ATTR_MODEL, self._model.model_id)
|
||||
|
||||
use_request_ids = model not in MODELS_PREVIOUS_INFO_NOT_SUPPORTED
|
||||
previous_request_ids: deque[str] = deque(maxlen=MAX_REQUEST_IDS)
|
||||
|
||||
base_stream_params = {
|
||||
"voice_id": voice_id,
|
||||
"model_id": model,
|
||||
"output_format": "mp3_44100_128",
|
||||
"voice_settings": self._voice_settings,
|
||||
}
|
||||
if language_code:
|
||||
base_stream_params["language_code"] = language_code
|
||||
|
||||
_LOGGER.debug("Starting TTS Stream with options: %s", base_stream_params)
|
||||
|
||||
async def _add_sentences() -> None:
|
||||
nonlocal sentences_complete
|
||||
|
||||
try:
|
||||
# Text chunks may not be on word or sentence boundaries
|
||||
async for text_chunk in text_stream:
|
||||
for sentence in boundary_detector.add_chunk(text_chunk):
|
||||
if not sentence.strip():
|
||||
continue
|
||||
|
||||
sentences.append(sentence)
|
||||
|
||||
if not sentences:
|
||||
continue
|
||||
|
||||
sentences_ready.set()
|
||||
|
||||
# Final sentence
|
||||
if text := boundary_detector.finish():
|
||||
sentences.append(text)
|
||||
finally:
|
||||
sentences_complete = True
|
||||
sentences_ready.set()
|
||||
|
||||
_add_sentences_task = self.hass.async_create_background_task(
|
||||
_add_sentences(), name="elevenlabs_tts_add_sentences"
|
||||
)
|
||||
|
||||
# Process new sentences as they're available, but synthesize the first
|
||||
# one immediately. While that's playing, synthesize (up to) the next 3
|
||||
# sentences. After that, synthesize all completed sentences as they're
|
||||
# available.
|
||||
sentence_schedule = [1, 3]
|
||||
while True:
|
||||
await sentences_ready.wait()
|
||||
|
||||
# Don't wait again if no more sentences are coming
|
||||
if not sentences_complete:
|
||||
sentences_ready.clear()
|
||||
|
||||
if not sentences:
|
||||
if sentences_complete:
|
||||
# Exit TTS loop
|
||||
_LOGGER.debug("No more sentences to process")
|
||||
break
|
||||
|
||||
# More sentences may be coming
|
||||
continue
|
||||
|
||||
new_sentences = sentences[:]
|
||||
sentences.clear()
|
||||
|
||||
while new_sentences:
|
||||
if sentence_schedule:
|
||||
max_sentences = sentence_schedule.pop(0)
|
||||
sentences_to_process = new_sentences[:max_sentences]
|
||||
new_sentences = new_sentences[len(sentences_to_process) :]
|
||||
else:
|
||||
# Process all available sentences together
|
||||
sentences_to_process = new_sentences[:]
|
||||
new_sentences.clear()
|
||||
|
||||
# Combine all new sentences completed to this point
|
||||
text = " ".join(sentences_to_process).strip()
|
||||
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# Build kwargs common to both modes
|
||||
kwargs = base_stream_params | {
|
||||
"text": text,
|
||||
}
|
||||
|
||||
# Provide previous_request_ids if supported.
|
||||
if previous_request_ids:
|
||||
# Send previous request ids.
|
||||
kwargs["previous_request_ids"] = list(previous_request_ids)
|
||||
|
||||
# Synthesize audio while text chunks are still being accumulated
|
||||
_LOGGER.debug("Synthesizing TTS for text: %s", text)
|
||||
try:
|
||||
async with self._client.text_to_speech.with_raw_response.stream(
|
||||
**kwargs
|
||||
) as stream:
|
||||
async for chunk_bytes in stream.data:
|
||||
yield chunk_bytes
|
||||
|
||||
if use_request_ids:
|
||||
if (rid := stream.headers.get("request-id")) is not None:
|
||||
previous_request_ids.append(rid)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"No request-id returned from server; clearing previous requests"
|
||||
)
|
||||
previous_request_ids.clear()
|
||||
except ApiError as exc:
|
||||
_LOGGER.warning(
|
||||
"Error during processing of TTS request %s", exc, exc_info=True
|
||||
)
|
||||
_add_sentences_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await _add_sentences_task
|
||||
raise HomeAssistantError(exc) from exc
|
||||
|
||||
# Capture and store server request-id for next calls (only when supported)
|
||||
_LOGGER.debug("Completed TTS stream for text: %s", text)
|
||||
|
||||
_LOGGER.debug("Completed TTS stream")
|
||||
|
||||
10
mypy.ini
generated
10
mypy.ini
generated
@@ -1576,16 +1576,6 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.elevenlabs.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.elgato.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
||||
3
requirements_all.txt
generated
3
requirements_all.txt
generated
@@ -2812,6 +2812,9 @@ sensorpush-ha==1.3.2
|
||||
# homeassistant.components.sensoterra
|
||||
sensoterra==2.0.1
|
||||
|
||||
# homeassistant.components.elevenlabs
|
||||
sentence-stream==1.2.0
|
||||
|
||||
# homeassistant.components.sentry
|
||||
sentry-sdk==1.45.1
|
||||
|
||||
|
||||
3
requirements_test_all.txt
generated
3
requirements_test_all.txt
generated
@@ -2337,6 +2337,9 @@ sensorpush-ha==1.3.2
|
||||
# homeassistant.components.sensoterra
|
||||
sensoterra==2.0.1
|
||||
|
||||
# homeassistant.components.elevenlabs
|
||||
sentence-stream==1.2.0
|
||||
|
||||
# homeassistant.components.sentry
|
||||
sentry-sdk==1.45.1
|
||||
|
||||
|
||||
@@ -2,9 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Self
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from elevenlabs.core import ApiError
|
||||
@@ -28,6 +31,7 @@ from homeassistant.components.media_player import (
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
from homeassistant.components.tts import TTSAudioRequest
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.core_config import async_process_ha_core_config
|
||||
@@ -37,17 +41,99 @@ from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
class FakeAudioGenerator:
|
||||
"""Mock audio generator for ElevenLabs TTS."""
|
||||
class _FakeResponse:
|
||||
def __init__(self, headers: dict[str, str]) -> None:
|
||||
self.headers = headers
|
||||
|
||||
def __aiter__(self):
|
||||
"""Mock async iterator for audio parts."""
|
||||
|
||||
async def _gen():
|
||||
yield b"audio-part-1"
|
||||
yield b"audio-part-2"
|
||||
class _AsyncByteStream:
|
||||
"""Async iterator that yields bytes and exposes response headers like ElevenLabs' stream."""
|
||||
|
||||
return _gen()
|
||||
def __init__(self, chunks: list[bytes], request_id: str | None = None) -> None:
|
||||
self._chunks = chunks
|
||||
self._i = 0
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[bytes]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
if self._i >= len(self._chunks):
|
||||
raise StopAsyncIteration
|
||||
b = self._chunks[self._i]
|
||||
self._i += 1
|
||||
await asyncio.sleep(0) # let loop breathe; mirrors real async iterator
|
||||
return b
|
||||
|
||||
|
||||
class _AsyncStreamResponse:
|
||||
"""Async context manager that mimics ElevenLabs raw stream responses."""
|
||||
|
||||
def __init__(self, chunks: list[bytes], request_id: str | None = None) -> None:
|
||||
self.headers = {"request-id": request_id} if request_id else {}
|
||||
self.data = _AsyncByteStream(chunks)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def capture_stream_calls(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Patches AsyncElevenLabs.text_to_speech.with_raw_response.stream and captures each call's kwargs.
|
||||
|
||||
Returns:
|
||||
calls: list[dict] — kwargs passed into each stream() invocation
|
||||
set_next_return(chunks, request_id): sets what the NEXT stream() call yields/returns
|
||||
"""
|
||||
calls: list[dict] = []
|
||||
state = {"chunks": [b"X"], "request_id": "rid-1"} # defaults; override per test
|
||||
|
||||
def set_next_return(
|
||||
*, chunks: list[bytes], request_id: str | None, error: Exception | None = None
|
||||
) -> None:
|
||||
state["chunks"] = chunks
|
||||
state["request_id"] = request_id
|
||||
state["error"] = error
|
||||
|
||||
def patch_stream(tts_entity):
|
||||
def _mock_stream(**kwargs):
|
||||
calls.append(kwargs)
|
||||
if state.get("error") is not None:
|
||||
raise state["error"]
|
||||
return _AsyncStreamResponse(
|
||||
chunks=list(state["chunks"]),
|
||||
request_id=state["request_id"],
|
||||
)
|
||||
|
||||
tts_entity._client.text_to_speech.with_raw_response.stream = _mock_stream
|
||||
|
||||
return calls, set_next_return, patch_stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_sentence_helpers():
|
||||
"""Return helpers for queue-driven sentence streaming."""
|
||||
|
||||
def factory(sentence_iter: Iterator[tuple], queue: asyncio.Queue[str | None]):
|
||||
async def get_next_part() -> tuple[Any, ...]:
|
||||
try:
|
||||
return next(sentence_iter)
|
||||
except StopIteration:
|
||||
await queue.put(None)
|
||||
return None, None, None
|
||||
|
||||
async def message_gen() -> AsyncIterator[str]:
|
||||
while True:
|
||||
part = await queue.get()
|
||||
if part is None:
|
||||
break
|
||||
yield part
|
||||
|
||||
return get_next_part, message_gen
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -134,15 +220,15 @@ async def test_tts_service_speak(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test tts service."""
|
||||
stream_calls, _, patch_stream = capture_stream_calls
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.text_to_speech.convert = MagicMock(
|
||||
return_value=FakeAudioGenerator()
|
||||
)
|
||||
patch_stream(tts_entity)
|
||||
|
||||
assert tts_entity._voice_settings == VoiceSettings(
|
||||
stability=DEFAULT_STABILITY,
|
||||
@@ -158,20 +244,22 @@ async def test_tts_service_speak(
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
assert len(stream_calls) == 1
|
||||
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "voice1")
|
||||
model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, "model1")
|
||||
language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
|
||||
|
||||
tts_entity._client.text_to_speech.convert.assert_called_once_with(
|
||||
text="There is a person at the front door.",
|
||||
voice_id=voice_id,
|
||||
model_id=model_id,
|
||||
voice_settings=tts_entity._voice_settings,
|
||||
)
|
||||
call_kwargs = stream_calls[0]
|
||||
assert call_kwargs["text"] == "There is a person at the front door."
|
||||
assert call_kwargs["voice_id"] == voice_id
|
||||
assert call_kwargs["model_id"] == model_id
|
||||
assert call_kwargs["voice_settings"] == tts_entity._voice_settings
|
||||
assert call_kwargs["output_format"] == "mp3_44100_128"
|
||||
assert call_kwargs["language_code"] == language
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -206,15 +294,15 @@ async def test_tts_service_speak_lang_config(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call say with other langcodes in the config."""
|
||||
stream_calls, _, patch_stream = capture_stream_calls
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.text_to_speech.convert = MagicMock(
|
||||
return_value=FakeAudioGenerator()
|
||||
)
|
||||
patch_stream(tts_entity)
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
@@ -223,18 +311,20 @@ async def test_tts_service_speak_lang_config(
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._client.text_to_speech.convert.assert_called_once_with(
|
||||
text="There is a person at the front door.",
|
||||
voice_id="voice1",
|
||||
model_id="model1",
|
||||
voice_settings=tts_entity._voice_settings,
|
||||
)
|
||||
assert len(stream_calls) == 1
|
||||
language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
|
||||
call_kwargs = stream_calls[0]
|
||||
assert call_kwargs["text"] == "There is a person at the front door."
|
||||
assert call_kwargs["voice_id"] == "voice1"
|
||||
assert call_kwargs["model_id"] == "model1"
|
||||
assert call_kwargs["voice_settings"] == tts_entity._voice_settings
|
||||
assert call_kwargs["output_format"] == "mp3_44100_128"
|
||||
assert call_kwargs["language_code"] == language
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -257,16 +347,16 @@ async def test_tts_service_speak_error(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call say with http response 400."""
|
||||
stream_calls, set_next_return, patch_stream = capture_stream_calls
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.text_to_speech.convert = MagicMock(
|
||||
return_value=FakeAudioGenerator()
|
||||
)
|
||||
tts_entity._client.text_to_speech.convert.side_effect = ApiError
|
||||
patch_stream(tts_entity)
|
||||
set_next_return(chunks=[], request_id=None, error=ApiError())
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
@@ -275,18 +365,20 @@ async def test_tts_service_speak_error(
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
tts_entity._client.text_to_speech.convert.assert_called_once_with(
|
||||
text="There is a person at the front door.",
|
||||
voice_id="voice1",
|
||||
model_id="model1",
|
||||
voice_settings=tts_entity._voice_settings,
|
||||
)
|
||||
assert len(stream_calls) == 1
|
||||
language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
|
||||
call_kwargs = stream_calls[0]
|
||||
assert call_kwargs["text"] == "There is a person at the front door."
|
||||
assert call_kwargs["voice_id"] == "voice1"
|
||||
assert call_kwargs["model_id"] == "model1"
|
||||
assert call_kwargs["voice_settings"] == tts_entity._voice_settings
|
||||
assert call_kwargs["output_format"] == "mp3_44100_128"
|
||||
assert call_kwargs["language_code"] == language
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -323,16 +415,17 @@ async def test_tts_service_speak_voice_settings(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
mock_similarity: float,
|
||||
) -> None:
|
||||
"""Test tts service."""
|
||||
stream_calls, _, patch_stream = capture_stream_calls
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.text_to_speech.convert = MagicMock(
|
||||
return_value=FakeAudioGenerator()
|
||||
)
|
||||
patch_stream(tts_entity)
|
||||
|
||||
assert tts_entity._voice_settings == VoiceSettings(
|
||||
stability=DEFAULT_STABILITY,
|
||||
similarity_boost=DEFAULT_SIMILARITY / 2,
|
||||
@@ -347,18 +440,20 @@ async def test_tts_service_speak_voice_settings(
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._client.text_to_speech.convert.assert_called_once_with(
|
||||
text="There is a person at the front door.",
|
||||
voice_id="voice2",
|
||||
model_id="model1",
|
||||
voice_settings=tts_entity._voice_settings,
|
||||
)
|
||||
assert len(stream_calls) == 1
|
||||
language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
|
||||
call_kwargs = stream_calls[0]
|
||||
assert call_kwargs["text"] == "There is a person at the front door."
|
||||
assert call_kwargs["voice_id"] == "voice2"
|
||||
assert call_kwargs["model_id"] == "model1"
|
||||
assert call_kwargs["voice_settings"] == tts_entity._voice_settings
|
||||
assert call_kwargs["output_format"] == "mp3_44100_128"
|
||||
assert call_kwargs["language_code"] == language
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -381,15 +476,15 @@ async def test_tts_service_speak_without_options(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
calls: list[ServiceCall],
|
||||
tts_service: str,
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test service call say with http response 200."""
|
||||
stream_calls, _, patch_stream = capture_stream_calls
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
|
||||
tts_entity._client.text_to_speech.convert = MagicMock(
|
||||
return_value=FakeAudioGenerator()
|
||||
)
|
||||
patch_stream(tts_entity)
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
@@ -398,17 +493,179 @@ async def test_tts_service_speak_without_options(
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
|
||||
tts_entity._client.text_to_speech.convert.assert_called_once_with(
|
||||
text="There is a person at the front door.",
|
||||
voice_id="voice1",
|
||||
voice_settings=VoiceSettings(
|
||||
stability=0.5, similarity_boost=0.75, style=0.0, use_speaker_boost=True
|
||||
),
|
||||
model_id="model1",
|
||||
assert len(stream_calls) == 1
|
||||
language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
|
||||
call_kwargs = stream_calls[0]
|
||||
assert call_kwargs["text"] == "There is a person at the front door."
|
||||
assert call_kwargs["voice_id"] == "voice1"
|
||||
assert call_kwargs["model_id"] == "model1"
|
||||
assert call_kwargs["voice_settings"] == VoiceSettings(
|
||||
stability=0.5,
|
||||
similarity_boost=0.75,
|
||||
style=0.0,
|
||||
use_speaker_boost=True,
|
||||
)
|
||||
assert call_kwargs["output_format"] == "mp3_44100_128"
|
||||
assert call_kwargs["language_code"] == language
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "model_id"),
|
||||
[
|
||||
("mock_config_entry_setup", "eleven_multilingual_v2"),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("message", "chunks", "request_ids"),
|
||||
[
|
||||
(
|
||||
[
|
||||
["One. ", "Two! ", "Three"],
|
||||
["! ", "Four"],
|
||||
["? ", "Five"],
|
||||
["! ", "Six!"],
|
||||
],
|
||||
[b"\x05\x06", b"\x07\x08", b"\x09\x0a", b"\x0b\x0c"],
|
||||
["rid-1", "rid-2", "rid-3", "rid-4"],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_stream_tts_with_request_ids(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
stream_sentence_helpers,
|
||||
model_id: str,
|
||||
message: list[list[str]],
|
||||
chunks: list[bytes],
|
||||
request_ids: list[str],
|
||||
) -> None:
|
||||
"""Test streaming TTS with request-id stitching."""
|
||||
calls, set_next_return, patch_stream = capture_stream_calls
|
||||
|
||||
# Access the TTS entity as in your existing tests; adjust if you use a fixture instead
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech")
|
||||
patch_stream(tts_entity)
|
||||
|
||||
# Use a queue to control when each part is yielded
|
||||
queue = asyncio.Queue()
|
||||
prev_request_ids: deque[str] = deque(maxlen=3) # keep last 3 request IDs
|
||||
sentence_iter = iter(zip(message, chunks, request_ids, strict=False))
|
||||
get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue)
|
||||
options = {tts.ATTR_VOICE: "voice1", "model": model_id}
|
||||
req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options)
|
||||
|
||||
resp = await tts_entity.async_stream_tts_audio(req)
|
||||
assert resp.extension == "mp3"
|
||||
|
||||
item, chunk, request_id = await get_next_part()
|
||||
if item is not None:
|
||||
for part in item:
|
||||
await queue.put(part)
|
||||
else:
|
||||
await queue.put(None)
|
||||
|
||||
set_next_return(chunks=[chunk], request_id=request_id)
|
||||
next_item, next_chunk, next_request_id = await get_next_part()
|
||||
# Consume bytes; after first chunk, switch next return to emulate second call
|
||||
async for b in resp.data_gen:
|
||||
assert b == chunk # each sentence yields its first chunk immediately
|
||||
assert "previous_text" not in calls[-1] # no previous_text for first sentence
|
||||
assert "next_text" not in calls[-1] # no next_text for first
|
||||
assert calls[-1].get("previous_request_ids", []) == (
|
||||
[] if len(calls) == 1 else list(prev_request_ids)
|
||||
)
|
||||
if request_id:
|
||||
prev_request_ids.append(request_id or "")
|
||||
item, chunk, request_id = next_item, next_chunk, next_request_id
|
||||
if item is not None:
|
||||
for part in item:
|
||||
await queue.put(part)
|
||||
set_next_return(chunks=[chunk], request_id=request_id)
|
||||
next_item, next_chunk, next_request_id = await get_next_part()
|
||||
if item is None:
|
||||
await queue.put(None)
|
||||
else:
|
||||
await queue.put(None)
|
||||
|
||||
# We expect two stream() invocations (one per sentence batch)
|
||||
assert len(calls) == len(message)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("message", "chunks", "request_ids"),
|
||||
[
|
||||
(
|
||||
[
|
||||
["This is the first sentence. ", "This is "],
|
||||
["the second sentence. "],
|
||||
],
|
||||
[b"\x05\x06", b"\x07\x08"],
|
||||
["rid-1", "rid-2"],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_stream_tts_without_previous_info(
|
||||
setup: AsyncMock,
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
capture_stream_calls,
|
||||
stream_sentence_helpers,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
message: list[list[str]],
|
||||
chunks: list[bytes],
|
||||
request_ids: list[str],
|
||||
) -> None:
|
||||
"""Test streaming TTS without request-id stitching (eleven_v3)."""
|
||||
calls, set_next_return, patch_stream = capture_stream_calls
|
||||
tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech")
|
||||
patch_stream(tts_entity)
|
||||
monkeypatch.setattr(
|
||||
"homeassistant.components.elevenlabs.tts.MODELS_PREVIOUS_INFO_NOT_SUPPORTED",
|
||||
("model1",),
|
||||
raising=False,
|
||||
)
|
||||
|
||||
queue = asyncio.Queue()
|
||||
sentence_iter = iter(zip(message, chunks, request_ids, strict=False))
|
||||
get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue)
|
||||
options = {tts.ATTR_VOICE: "voice1", "model": "model1"}
|
||||
req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options)
|
||||
|
||||
resp = await tts_entity.async_stream_tts_audio(req)
|
||||
assert resp.extension == "mp3"
|
||||
|
||||
item, chunk, request_id = await get_next_part()
|
||||
if item is not None:
|
||||
for part in item:
|
||||
await queue.put(part)
|
||||
else:
|
||||
await queue.put(None)
|
||||
|
||||
set_next_return(chunks=[chunk], request_id=request_id)
|
||||
next_item, next_chunk, next_request_id = await get_next_part()
|
||||
# Consume bytes; after first chunk, switch next return to emulate second call
|
||||
async for b in resp.data_gen:
|
||||
assert b == chunk # each sentence yields its first chunk immediately
|
||||
assert "previous_request_ids" not in calls[-1] # no previous_request_ids
|
||||
|
||||
item, chunk, request_id = next_item, next_chunk, next_request_id
|
||||
if item is not None:
|
||||
for part in item:
|
||||
await queue.put(part)
|
||||
set_next_return(chunks=[chunk], request_id=request_id)
|
||||
next_item, next_chunk, next_request_id = await get_next_part()
|
||||
if item is None:
|
||||
await queue.put(None)
|
||||
else:
|
||||
await queue.put(None)
|
||||
|
||||
# We expect two stream() invocations (one per sentence batch)
|
||||
assert len(calls) == len(message)
|
||||
|
||||
Reference in New Issue
Block a user