
Add streaming to ElevenLabs TTS (#154663)

Author: ehendrix23
Date: 2025-10-18 15:50:01 -06:00
Committed by: GitHub
parent 7af3eb638b
commit f1e72c1616
9 changed files with 493 additions and 77 deletions

.strict-typing

@@ -182,7 +182,6 @@ homeassistant.components.efergy.*
 homeassistant.components.eheimdigital.*
 homeassistant.components.electrasmart.*
 homeassistant.components.electric_kiwi.*
-homeassistant.components.elevenlabs.*
 homeassistant.components.elgato.*
 homeassistant.components.elkm1.*
 homeassistant.components.emulated_hue.*

homeassistant/components/elevenlabs/const.py

@@ -21,6 +21,9 @@ DEFAULT_STT_MODEL = "scribe_v1"
 DEFAULT_STYLE = 0
 DEFAULT_USE_SPEAKER_BOOST = True
+MAX_REQUEST_IDS = 3
+MODELS_PREVIOUS_INFO_NOT_SUPPORTED = ("eleven_v3",)
+
 STT_LANGUAGES = [
     "af-ZA",  # Afrikaans
     "am-ET",  # Amharic

homeassistant/components/elevenlabs/manifest.json

@@ -7,5 +7,5 @@
   "integration_type": "service",
   "iot_class": "cloud_polling",
   "loggers": ["elevenlabs"],
-  "requirements": ["elevenlabs==2.3.0"]
+  "requirements": ["elevenlabs==2.3.0", "sentence-stream==1.2.0"]
 }
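The new sentence-stream requirement provides SentenceBoundaryDetector, which tts.py uses to turn arbitrary text chunks into whole sentences. A rough usage sketch covering only the add_chunk()/finish() API exercised by this commit (the exact splitting is up to the library; the chunks here are invented):

from sentence_stream import SentenceBoundaryDetector

detector = SentenceBoundaryDetector()
sentences: list[str] = []
# Chunks need not align with word or sentence boundaries.
for chunk in ("Hello the", "re! How are", " you? Fine"):
    sentences.extend(detector.add_chunk(chunk))
# finish() flushes whatever text remains after the last chunk.
if remainder := detector.finish():
    sentences.append(remainder)
# sentences is now roughly ["Hello there!", "How are you?", "Fine"]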

homeassistant/components/elevenlabs/quality_scale.yaml

@@ -85,4 +85,4 @@ rules:
   # Platinum
   async-dependency: done
   inject-websession: done
-  strict-typing: done
+  strict-typing: todo

homeassistant/components/elevenlabs/tts.py

@@ -2,17 +2,23 @@
 
 from __future__ import annotations
 
-from collections.abc import Mapping
+import asyncio
+from collections import deque
+from collections.abc import AsyncGenerator, Mapping
+import contextlib
 import logging
 from typing import Any
 
 from elevenlabs import AsyncElevenLabs
 from elevenlabs.core import ApiError
 from elevenlabs.types import Model, Voice as ElevenLabsVoice, VoiceSettings
+from sentence_stream import SentenceBoundaryDetector
 
 from homeassistant.components.tts import (
     ATTR_VOICE,
     TextToSpeechEntity,
+    TTSAudioRequest,
+    TTSAudioResponse,
     TtsAudioType,
     Voice,
 )
@@ -35,10 +41,12 @@ from .const import (
     DEFAULT_STYLE,
     DEFAULT_USE_SPEAKER_BOOST,
     DOMAIN,
+    MAX_REQUEST_IDS,
+    MODELS_PREVIOUS_INFO_NOT_SUPPORTED,
 )
 
 _LOGGER = logging.getLogger(__name__)
 
-PARALLEL_UPDATES = 0
+PARALLEL_UPDATES = 6
 
 
 def to_voice_settings(options: Mapping[str, Any]) -> VoiceSettings:
@@ -122,7 +130,12 @@ class ElevenLabsTTSEntity(TextToSpeechEntity):
         self._attr_supported_languages = [
             lang.language_id for lang in self._model.languages or []
         ]
-        self._attr_default_language = self.supported_languages[0]
+        # Use the first supported language as the default if available
+        self._attr_default_language = (
+            self._attr_supported_languages[0]
+            if self._attr_supported_languages
+            else "en"
+        )
 
     def async_get_supported_voices(self, language: str) -> list[Voice]:
         """Return a list of supported voices for a language."""
@@ -151,3 +164,151 @@ class ElevenLabsTTSEntity(TextToSpeechEntity):
             )
             raise HomeAssistantError(exc) from exc
         return "mp3", bytes_combined
+
+    async def async_stream_tts_audio(
+        self, request: TTSAudioRequest
+    ) -> TTSAudioResponse:
+        """Generate speech from an incoming message."""
+        _LOGGER.debug(
+            "Getting TTS audio for language %s and options: %s",
+            request.language,
+            request.options,
+        )
+        return TTSAudioResponse("mp3", self._process_tts_stream(request))
+
+    async def _process_tts_stream(
+        self, request: TTSAudioRequest
+    ) -> AsyncGenerator[bytes]:
+        """Generate speech from an incoming message."""
+        text_stream = request.message_gen
+        boundary_detector = SentenceBoundaryDetector()
+        sentences: list[str] = []
+        sentences_ready = asyncio.Event()
+        sentences_complete = False
+
+        language_code: str | None = request.language
+        voice_id = request.options.get(ATTR_VOICE, self._default_voice_id)
+        model = request.options.get(ATTR_MODEL, self._model.model_id)
+        use_request_ids = model not in MODELS_PREVIOUS_INFO_NOT_SUPPORTED
+        previous_request_ids: deque[str] = deque(maxlen=MAX_REQUEST_IDS)
+
+        base_stream_params = {
+            "voice_id": voice_id,
+            "model_id": model,
+            "output_format": "mp3_44100_128",
+            "voice_settings": self._voice_settings,
+        }
+        if language_code:
+            base_stream_params["language_code"] = language_code
+
+        _LOGGER.debug("Starting TTS stream with options: %s", base_stream_params)
+
+        async def _add_sentences() -> None:
+            nonlocal sentences_complete
+            try:
+                # Text chunks may not be on word or sentence boundaries
+                async for text_chunk in text_stream:
+                    for sentence in boundary_detector.add_chunk(text_chunk):
+                        if not sentence.strip():
+                            continue
+                        sentences.append(sentence)
+                    if not sentences:
+                        continue
+                    sentences_ready.set()
+
+                # Final sentence
+                if text := boundary_detector.finish():
+                    sentences.append(text)
+            finally:
+                sentences_complete = True
+                sentences_ready.set()
+
+        _add_sentences_task = self.hass.async_create_background_task(
+            _add_sentences(), name="elevenlabs_tts_add_sentences"
+        )
+
+        # Process new sentences as they're available, but synthesize the first
+        # one immediately. While that's playing, synthesize (up to) the next 3
+        # sentences. After that, synthesize all completed sentences as they're
+        # available.
+        sentence_schedule = [1, 3]
+
+        while True:
+            await sentences_ready.wait()
+
+            # Don't wait again if no more sentences are coming
+            if not sentences_complete:
+                sentences_ready.clear()
+
+            if not sentences:
+                if sentences_complete:
+                    # Exit TTS loop
+                    _LOGGER.debug("No more sentences to process")
+                    break
+                # More sentences may be coming
+                continue
+
+            new_sentences = sentences[:]
+            sentences.clear()
+
+            while new_sentences:
+                if sentence_schedule:
+                    max_sentences = sentence_schedule.pop(0)
+                    sentences_to_process = new_sentences[:max_sentences]
+                    new_sentences = new_sentences[len(sentences_to_process) :]
+                else:
+                    # Process all available sentences together
+                    sentences_to_process = new_sentences[:]
+                    new_sentences.clear()
+
+                # Combine all new sentences completed to this point
+                text = " ".join(sentences_to_process).strip()
+                if not text:
+                    continue
+
+                kwargs = base_stream_params | {"text": text}
+
+                # Send previous request ids when the model supports them
+                if previous_request_ids:
+                    kwargs["previous_request_ids"] = list(previous_request_ids)
+
+                # Synthesize audio while text chunks are still being accumulated
+                _LOGGER.debug("Synthesizing TTS for text: %s", text)
+                try:
+                    async with self._client.text_to_speech.with_raw_response.stream(
+                        **kwargs
+                    ) as stream:
+                        async for chunk_bytes in stream.data:
+                            yield chunk_bytes
+                        if use_request_ids:
+                            # Capture and store the server request-id for the
+                            # next calls (only when supported)
+                            if (rid := stream.headers.get("request-id")) is not None:
+                                previous_request_ids.append(rid)
+                            else:
+                                _LOGGER.debug(
+                                    "No request-id returned from server; clearing previous request ids"
+                                )
+                                previous_request_ids.clear()
+                except ApiError as exc:
+                    _LOGGER.warning(
+                        "Error during processing of TTS request: %s", exc, exc_info=True
+                    )
+                    _add_sentences_task.cancel()
+                    with contextlib.suppress(asyncio.CancelledError):
+                        await _add_sentences_task
+                    raise HomeAssistantError(exc) from exc
+
+                _LOGGER.debug("Completed TTS stream for text: %s", text)
+
+        _LOGGER.debug("Completed TTS stream")

mypy.ini generated

@@ -1576,16 +1576,6 @@ disallow_untyped_defs = true
 warn_return_any = true
 warn_unreachable = true
 
-[mypy-homeassistant.components.elevenlabs.*]
-check_untyped_defs = true
-disallow_incomplete_defs = true
-disallow_subclassing_any = true
-disallow_untyped_calls = true
-disallow_untyped_decorators = true
-disallow_untyped_defs = true
-warn_return_any = true
-warn_unreachable = true
-
 [mypy-homeassistant.components.elgato.*]
 check_untyped_defs = true
 disallow_incomplete_defs = true

requirements_all.txt generated

@@ -2812,6 +2812,9 @@ sensorpush-ha==1.3.2
 # homeassistant.components.sensoterra
 sensoterra==2.0.1
 
+# homeassistant.components.elevenlabs
+sentence-stream==1.2.0
+
 # homeassistant.components.sentry
 sentry-sdk==1.45.1

requirements_test_all.txt generated

@@ -2337,6 +2337,9 @@ sensorpush-ha==1.3.2
 # homeassistant.components.sensoterra
 sensoterra==2.0.1
 
+# homeassistant.components.elevenlabs
+sentence-stream==1.2.0
+
 # homeassistant.components.sentry
 sentry-sdk==1.45.1

tests/components/elevenlabs/test_tts.py

@@ -2,9 +2,12 @@
 
 from __future__ import annotations
 
+import asyncio
+from collections import deque
+from collections.abc import AsyncIterator, Iterator
 from http import HTTPStatus
 from pathlib import Path
-from typing import Any
+from typing import Any, Self
 from unittest.mock import AsyncMock, MagicMock
 
 from elevenlabs.core import ApiError
@@ -28,6 +31,7 @@ from homeassistant.components.media_player import (
     DOMAIN as DOMAIN_MP,
     SERVICE_PLAY_MEDIA,
 )
+from homeassistant.components.tts import TTSAudioRequest
 from homeassistant.const import ATTR_ENTITY_ID
 from homeassistant.core import HomeAssistant, ServiceCall
 from homeassistant.core_config import async_process_ha_core_config
@@ -37,17 +41,99 @@ from tests.components.tts.common import retrieve_media
 from tests.typing import ClientSessionGenerator
 
 
-class FakeAudioGenerator:
-    """Mock audio generator for ElevenLabs TTS."""
-
-    def __aiter__(self):
-        """Mock async iterator for audio parts."""
-
-        async def _gen():
-            yield b"audio-part-1"
-            yield b"audio-part-2"
-
-        return _gen()
+class _FakeResponse:
+    def __init__(self, headers: dict[str, str]) -> None:
+        self.headers = headers
+
+
+class _AsyncByteStream:
+    """Async iterator that yields bytes and exposes response headers like ElevenLabs' stream."""
+
+    def __init__(self, chunks: list[bytes], request_id: str | None = None) -> None:
+        self._chunks = chunks
+        self._i = 0
+
+    def __aiter__(self) -> AsyncIterator[bytes]:
+        return self
+
+    async def __anext__(self) -> bytes:
+        if self._i >= len(self._chunks):
+            raise StopAsyncIteration
+        b = self._chunks[self._i]
+        self._i += 1
+        await asyncio.sleep(0)  # let the event loop breathe; mirrors a real async iterator
+        return b
+
+
+class _AsyncStreamResponse:
+    """Async context manager that mimics ElevenLabs raw stream responses."""
+
+    def __init__(self, chunks: list[bytes], request_id: str | None = None) -> None:
+        self.headers = {"request-id": request_id} if request_id else {}
+        self.data = _AsyncByteStream(chunks)
+
+    async def __aenter__(self) -> Self:
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb) -> None:
+        return None
+
+
+@pytest.fixture
+def capture_stream_calls(monkeypatch: pytest.MonkeyPatch):
+    """Patch AsyncElevenLabs.text_to_speech.with_raw_response.stream and capture each call's kwargs.
+
+    Returns:
+        calls: list[dict] - kwargs passed into each stream() invocation
+        set_next_return(chunks, request_id): sets what the NEXT stream() call yields/returns
+        patch_stream(tts_entity): installs the mock stream on the given entity
+    """
+    calls: list[dict] = []
+    state = {"chunks": [b"X"], "request_id": "rid-1"}  # defaults; override per test
+
+    def set_next_return(
+        *, chunks: list[bytes], request_id: str | None, error: Exception | None = None
+    ) -> None:
+        state["chunks"] = chunks
+        state["request_id"] = request_id
+        state["error"] = error
+
+    def patch_stream(tts_entity):
+        def _mock_stream(**kwargs):
+            calls.append(kwargs)
+            if state.get("error") is not None:
+                raise state["error"]
+            return _AsyncStreamResponse(
+                chunks=list(state["chunks"]),
+                request_id=state["request_id"],
+            )
+
+        tts_entity._client.text_to_speech.with_raw_response.stream = _mock_stream
+
+    return calls, set_next_return, patch_stream
+
+
+@pytest.fixture
+def stream_sentence_helpers():
+    """Return helpers for queue-driven sentence streaming."""
+
+    def factory(sentence_iter: Iterator[tuple], queue: asyncio.Queue[str | None]):
+        async def get_next_part() -> tuple[Any, ...]:
+            try:
+                return next(sentence_iter)
+            except StopIteration:
+                await queue.put(None)
+                return None, None, None
+
+        async def message_gen() -> AsyncIterator[str]:
+            while True:
+                part = await queue.get()
+                if part is None:
+                    break
+                yield part
+
+        return get_next_part, message_gen
+
+    return factory
+
+
 @pytest.fixture(autouse=True)
@@ -134,15 +220,15 @@ async def test_tts_service_speak(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
     calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
 ) -> None:
     """Test tts service."""
+    stream_calls, _, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
+    patch_stream(tts_entity)
 
     assert tts_entity._voice_settings == VoiceSettings(
         stability=DEFAULT_STABILITY,
@@ -158,20 +244,22 @@ async def test_tts_service_speak(
         blocking=True,
     )
     assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.OK
     )
+    assert len(stream_calls) == 1
 
     voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "voice1")
     model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, "model1")
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
 
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id=voice_id,
-        model_id=model_id,
-        voice_settings=tts_entity._voice_settings,
-    )
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+    assert call_kwargs["voice_id"] == voice_id
+    assert call_kwargs["model_id"] == model_id
+    assert call_kwargs["voice_settings"] == tts_entity._voice_settings
+    assert call_kwargs["output_format"] == "mp3_44100_128"
+    assert call_kwargs["language_code"] == language
@pytest.mark.parametrize(
@@ -206,15 +294,15 @@ async def test_tts_service_speak_lang_config(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
    calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
 ) -> None:
     """Test service call say with other langcodes in the config."""
+    stream_calls, _, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
+    patch_stream(tts_entity)
 
     await hass.services.async_call(
         tts.DOMAIN,
@@ -223,18 +311,20 @@ async def test_tts_service_speak_lang_config(
         blocking=True,
     )
     assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.OK
     )
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id="voice1",
-        model_id="model1",
-        voice_settings=tts_entity._voice_settings,
-    )
+    assert len(stream_calls) == 1
+
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+    assert call_kwargs["voice_id"] == "voice1"
+    assert call_kwargs["model_id"] == "model1"
+    assert call_kwargs["voice_settings"] == tts_entity._voice_settings
+    assert call_kwargs["output_format"] == "mp3_44100_128"
+    assert call_kwargs["language_code"] == language
@pytest.mark.parametrize(
@@ -257,16 +347,16 @@ async def test_tts_service_speak_error(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
     calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
 ) -> None:
     """Test service call say with http response 400."""
+    stream_calls, set_next_return, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
-    tts_entity._client.text_to_speech.convert.side_effect = ApiError
+    patch_stream(tts_entity)
+    set_next_return(chunks=[], request_id=None, error=ApiError())
 
     await hass.services.async_call(
         tts.DOMAIN,
@@ -275,18 +365,20 @@ async def test_tts_service_speak_error(
         blocking=True,
     )
     assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.INTERNAL_SERVER_ERROR
     )
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id="voice1",
-        model_id="model1",
-        voice_settings=tts_entity._voice_settings,
-    )
+    assert len(stream_calls) == 1
+
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+    assert call_kwargs["voice_id"] == "voice1"
+    assert call_kwargs["model_id"] == "model1"
+    assert call_kwargs["voice_settings"] == tts_entity._voice_settings
+    assert call_kwargs["output_format"] == "mp3_44100_128"
+    assert call_kwargs["language_code"] == language
@pytest.mark.parametrize(
@@ -323,16 +415,17 @@ async def test_tts_service_speak_voice_settings(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
     calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
     mock_similarity: float,
 ) -> None:
     """Test tts service."""
+    stream_calls, _, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
+    patch_stream(tts_entity)
 
     assert tts_entity._voice_settings == VoiceSettings(
         stability=DEFAULT_STABILITY,
         similarity_boost=DEFAULT_SIMILARITY / 2,
@@ -347,18 +440,20 @@ async def test_tts_service_speak_voice_settings(
         blocking=True,
     )
     assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.OK
     )
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id="voice2",
-        model_id="model1",
-        voice_settings=tts_entity._voice_settings,
-    )
+    assert len(stream_calls) == 1
+
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+    assert call_kwargs["voice_id"] == "voice2"
+    assert call_kwargs["model_id"] == "model1"
+    assert call_kwargs["voice_settings"] == tts_entity._voice_settings
+    assert call_kwargs["output_format"] == "mp3_44100_128"
+    assert call_kwargs["language_code"] == language
@pytest.mark.parametrize(
@@ -381,15 +476,15 @@ async def test_tts_service_speak_without_options(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
     calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
 ) -> None:
     """Test service call say with http response 200."""
+    stream_calls, _, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
+    patch_stream(tts_entity)
 
     await hass.services.async_call(
         tts.DOMAIN,
@@ -398,17 +493,179 @@ async def test_tts_service_speak_without_options(
         blocking=True,
     )
     assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.OK
     )
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id="voice1",
-        voice_settings=VoiceSettings(
-            stability=0.5, similarity_boost=0.75, style=0.0, use_speaker_boost=True
-        ),
-        model_id="model1",
-    )
+    assert len(stream_calls) == 1
+
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+    assert call_kwargs["voice_id"] == "voice1"
+    assert call_kwargs["model_id"] == "model1"
+    assert call_kwargs["voice_settings"] == VoiceSettings(
+        stability=0.5,
+        similarity_boost=0.75,
+        style=0.0,
+        use_speaker_boost=True,
+    )
+    assert call_kwargs["output_format"] == "mp3_44100_128"
+    assert call_kwargs["language_code"] == language
+
+
+@pytest.mark.parametrize(
+    ("setup", "model_id"),
+    [
+        ("mock_config_entry_setup", "eleven_multilingual_v2"),
+    ],
+    indirect=["setup"],
+)
+@pytest.mark.parametrize(
+    ("message", "chunks", "request_ids"),
+    [
+        (
+            [
+                ["One. ", "Two! ", "Three"],
+                ["! ", "Four"],
+                ["? ", "Five"],
+                ["! ", "Six!"],
+            ],
+            [b"\x05\x06", b"\x07\x08", b"\x09\x0a", b"\x0b\x0c"],
+            ["rid-1", "rid-2", "rid-3", "rid-4"],
+        ),
+    ],
+)
+async def test_stream_tts_with_request_ids(
+    setup: AsyncMock,
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    capture_stream_calls,
+    stream_sentence_helpers,
+    model_id: str,
+    message: list[list[str]],
+    chunks: list[bytes],
+    request_ids: list[str],
+) -> None:
+    """Test streaming TTS with request-id stitching."""
+    calls, set_next_return, patch_stream = capture_stream_calls
+    # Look up the TTS entity registered by the config entry fixture
+    tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech")
+    patch_stream(tts_entity)
+
+    # Use a queue to control when each part is yielded
+    queue: asyncio.Queue[str | None] = asyncio.Queue()
+    prev_request_ids: deque[str] = deque(maxlen=3)  # keep the last 3 request ids
+    sentence_iter = iter(zip(message, chunks, request_ids, strict=False))
+    get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue)
+
+    options = {tts.ATTR_VOICE: "voice1", "model": model_id}
+    req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options)
+    resp = await tts_entity.async_stream_tts_audio(req)
+    assert resp.extension == "mp3"
+
+    item, chunk, request_id = await get_next_part()
+    if item is not None:
+        for part in item:
+            await queue.put(part)
+    else:
+        await queue.put(None)
+    set_next_return(chunks=[chunk], request_id=request_id)
+    next_item, next_chunk, next_request_id = await get_next_part()
+
+    # Consume bytes; after each chunk, switch the next return to emulate the next call
+    async for b in resp.data_gen:
+        assert b == chunk  # each sentence yields its first chunk immediately
+        assert "previous_text" not in calls[-1]  # previous_text is never sent
+        assert "next_text" not in calls[-1]  # next_text is never sent
+        assert calls[-1].get("previous_request_ids", []) == (
+            [] if len(calls) == 1 else list(prev_request_ids)
+        )
+        if request_id:
+            prev_request_ids.append(request_id)
+
+        item, chunk, request_id = next_item, next_chunk, next_request_id
+        if item is not None:
+            for part in item:
+                await queue.put(part)
+            set_next_return(chunks=[chunk], request_id=request_id)
+            next_item, next_chunk, next_request_id = await get_next_part()
+        else:
+            await queue.put(None)
+
+    # One stream() invocation per sentence batch
+    assert len(calls) == len(message)
+
+
+@pytest.mark.parametrize(
+    ("message", "chunks", "request_ids"),
+    [
+        (
+            [
+                ["This is the first sentence. ", "This is "],
+                ["the second sentence. "],
+            ],
+            [b"\x05\x06", b"\x07\x08"],
+            ["rid-1", "rid-2"],
+        ),
+    ],
+)
+async def test_stream_tts_without_previous_info(
+    setup: AsyncMock,
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    capture_stream_calls,
+    stream_sentence_helpers,
+    monkeypatch: pytest.MonkeyPatch,
+    message: list[list[str]],
+    chunks: list[bytes],
+    request_ids: list[str],
+) -> None:
+    """Test streaming TTS without request-id stitching (eleven_v3-style models)."""
+    calls, set_next_return, patch_stream = capture_stream_calls
+    tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech")
+    patch_stream(tts_entity)
+    monkeypatch.setattr(
+        "homeassistant.components.elevenlabs.tts.MODELS_PREVIOUS_INFO_NOT_SUPPORTED",
+        ("model1",),
+        raising=False,
+    )
+
+    queue: asyncio.Queue[str | None] = asyncio.Queue()
+    sentence_iter = iter(zip(message, chunks, request_ids, strict=False))
+    get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue)
+
+    options = {tts.ATTR_VOICE: "voice1", "model": "model1"}
+    req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options)
+    resp = await tts_entity.async_stream_tts_audio(req)
+    assert resp.extension == "mp3"
+
+    item, chunk, request_id = await get_next_part()
+    if item is not None:
+        for part in item:
+            await queue.put(part)
+    else:
+        await queue.put(None)
+    set_next_return(chunks=[chunk], request_id=request_id)
+    next_item, next_chunk, next_request_id = await get_next_part()
+
+    # Consume bytes; after each chunk, switch the next return to emulate the next call
+    async for b in resp.data_gen:
+        assert b == chunk  # each sentence yields its first chunk immediately
+        assert "previous_request_ids" not in calls[-1]  # stitching is disabled
+
+        item, chunk, request_id = next_item, next_chunk, next_request_id
+        if item is not None:
+            for part in item:
+                await queue.put(part)
+            set_next_return(chunks=[chunk], request_id=request_id)
+            next_item, next_chunk, next_request_id = await get_next_part()
+        else:
+            await queue.put(None)
+
+    # One stream() invocation per sentence batch
+    assert len(calls) == len(message)