diff --git a/.strict-typing b/.strict-typing
index e125deb7cac..91aa90df028 100644
--- a/.strict-typing
+++ b/.strict-typing
@@ -182,7 +182,6 @@ homeassistant.components.efergy.*
 homeassistant.components.eheimdigital.*
 homeassistant.components.electrasmart.*
 homeassistant.components.electric_kiwi.*
-homeassistant.components.elevenlabs.*
 homeassistant.components.elgato.*
 homeassistant.components.elkm1.*
 homeassistant.components.emulated_hue.*
diff --git a/homeassistant/components/elevenlabs/const.py b/homeassistant/components/elevenlabs/const.py
index 5d7aab7dbb6..c424e0a4588 100644
--- a/homeassistant/components/elevenlabs/const.py
+++ b/homeassistant/components/elevenlabs/const.py
@@ -21,6 +21,9 @@ DEFAULT_STT_MODEL = "scribe_v1"
 DEFAULT_STYLE = 0
 DEFAULT_USE_SPEAKER_BOOST = True
 
+MAX_REQUEST_IDS = 3
+MODELS_PREVIOUS_INFO_NOT_SUPPORTED = ("eleven_v3",)
+
 STT_LANGUAGES = [
     "af-ZA",  # Afrikaans
     "am-ET",  # Amharic
diff --git a/homeassistant/components/elevenlabs/manifest.json b/homeassistant/components/elevenlabs/manifest.json
index f36a2383576..36d5b6aa3aa 100644
--- a/homeassistant/components/elevenlabs/manifest.json
+++ b/homeassistant/components/elevenlabs/manifest.json
@@ -7,5 +7,5 @@
   "integration_type": "service",
   "iot_class": "cloud_polling",
   "loggers": ["elevenlabs"],
-  "requirements": ["elevenlabs==2.3.0"]
+  "requirements": ["elevenlabs==2.3.0", "sentence-stream==1.2.0"]
 }
diff --git a/homeassistant/components/elevenlabs/quality_scale.yaml b/homeassistant/components/elevenlabs/quality_scale.yaml
index 94c395310c5..99658e555a8 100644
--- a/homeassistant/components/elevenlabs/quality_scale.yaml
+++ b/homeassistant/components/elevenlabs/quality_scale.yaml
@@ -85,4 +85,4 @@ rules:
   # Platinum
   async-dependency: done
   inject-websession: done
-  strict-typing: done
+  strict-typing: todo
diff --git a/homeassistant/components/elevenlabs/tts.py b/homeassistant/components/elevenlabs/tts.py
index 21da81cef6f..b1c26093cf9 100644
--- a/homeassistant/components/elevenlabs/tts.py
+++ b/homeassistant/components/elevenlabs/tts.py
@@ -2,17 +2,23 @@
 
 from __future__ import annotations
 
-from collections.abc import Mapping
+import asyncio
+from collections import deque
+from collections.abc import AsyncGenerator, Mapping
+import contextlib
 import logging
 from typing import Any
 
 from elevenlabs import AsyncElevenLabs
 from elevenlabs.core import ApiError
 from elevenlabs.types import Model, Voice as ElevenLabsVoice, VoiceSettings
+from sentence_stream import SentenceBoundaryDetector
 
 from homeassistant.components.tts import (
     ATTR_VOICE,
     TextToSpeechEntity,
+    TTSAudioRequest,
+    TTSAudioResponse,
     TtsAudioType,
     Voice,
 )
@@ -35,10 +41,12 @@ from .const import (
     DEFAULT_STYLE,
     DEFAULT_USE_SPEAKER_BOOST,
     DOMAIN,
+    MAX_REQUEST_IDS,
+    MODELS_PREVIOUS_INFO_NOT_SUPPORTED,
 )
 
 _LOGGER = logging.getLogger(__name__)
-PARALLEL_UPDATES = 0
+PARALLEL_UPDATES = 6
 
 
 def to_voice_settings(options: Mapping[str, Any]) -> VoiceSettings:
@@ -122,7 +130,12 @@ class ElevenLabsTTSEntity(TextToSpeechEntity):
         self._attr_supported_languages = [
             lang.language_id for lang in self._model.languages or []
         ]
-        self._attr_default_language = self.supported_languages[0]
+        # Use the first supported language as the default if available
+        self._attr_default_language = (
+            self._attr_supported_languages[0]
+            if self._attr_supported_languages
+            else "en"
+        )
 
     def async_get_supported_voices(self, language: str) -> list[Voice]:
         """Return a list of supported voices for a language."""
@@ -151,3 +164,151 @@
             )
             raise HomeAssistantError(exc) from exc
         return "mp3", bytes_combined
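+
+    # Streaming overview (editorial note): async_stream_tts_audio receives the
+    # message as an async stream of text chunks (for example, tokens from an
+    # LLM). A background task groups those chunks into sentences, and
+    # _process_tts_stream synthesizes each sentence batch while later text is
+    # still arriving.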
+    async def async_stream_tts_audio(
+        self, request: TTSAudioRequest
+    ) -> TTSAudioResponse:
+        """Generate speech from an incoming message."""
+        _LOGGER.debug(
+            "Getting TTS audio for language %s and options: %s",
+            request.language,
+            request.options,
+        )
+        return TTSAudioResponse("mp3", self._process_tts_stream(request))
+
+    async def _process_tts_stream(
+        self, request: TTSAudioRequest
+    ) -> AsyncGenerator[bytes]:
+        """Generate speech from an incoming message."""
+        text_stream = request.message_gen
+        boundary_detector = SentenceBoundaryDetector()
+        sentences: list[str] = []
+        sentences_ready = asyncio.Event()
+        sentences_complete = False
+
+        language_code: str | None = request.language
+        voice_id = request.options.get(ATTR_VOICE, self._default_voice_id)
+        model = request.options.get(ATTR_MODEL, self._model.model_id)
+
+        use_request_ids = model not in MODELS_PREVIOUS_INFO_NOT_SUPPORTED
+        previous_request_ids: deque[str] = deque(maxlen=MAX_REQUEST_IDS)
+
+        base_stream_params = {
+            "voice_id": voice_id,
+            "model_id": model,
+            "output_format": "mp3_44100_128",
+            "voice_settings": self._voice_settings,
+        }
+        if language_code:
+            base_stream_params["language_code"] = language_code
+
+        _LOGGER.debug("Starting TTS stream with options: %s", base_stream_params)
+
+        async def _add_sentences() -> None:
+            nonlocal sentences_complete
+
+            try:
+                # Text chunks may not be on word or sentence boundaries
+                async for text_chunk in text_stream:
+                    for sentence in boundary_detector.add_chunk(text_chunk):
+                        if not sentence.strip():
+                            continue
+
+                        sentences.append(sentence)
+
+                    if not sentences:
+                        continue
+
+                    sentences_ready.set()
+
+                # Final sentence
+                if text := boundary_detector.finish():
+                    sentences.append(text)
+            finally:
+                sentences_complete = True
+                sentences_ready.set()
+
+        _add_sentences_task = self.hass.async_create_background_task(
+            _add_sentences(), name="elevenlabs_tts_add_sentences"
+        )
+
+        # Process new sentences as they're available, but synthesize the first
+        # one immediately. While that's playing, synthesize (up to) the next 3
+        # sentences. After that, synthesize all completed sentences as they're
+        # available.
+        sentence_schedule = [1, 3]
+        while True:
+            await sentences_ready.wait()
+
+            # Don't wait again if no more sentences are coming
+            if not sentences_complete:
+                sentences_ready.clear()
+
+            if not sentences:
+                if sentences_complete:
+                    # Exit TTS loop
+                    _LOGGER.debug("No more sentences to process")
+                    break
+
+                # More sentences may be coming
+                continue
+
+            new_sentences = sentences[:]
+            sentences.clear()
+
+            while new_sentences:
+                if sentence_schedule:
+                    max_sentences = sentence_schedule.pop(0)
+                    sentences_to_process = new_sentences[:max_sentences]
+                    new_sentences = new_sentences[len(sentences_to_process) :]
+                else:
+                    # Process all available sentences together
+                    sentences_to_process = new_sentences[:]
+                    new_sentences.clear()
+
+                # Combine all new sentences completed to this point
+                text = " ".join(sentences_to_process).strip()
+
+                if not text:
+                    continue
+
+                kwargs = base_stream_params | {"text": text}
+
+                # Send previous request ids if the model supports stitching
+                if previous_request_ids:
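+                    # Only the most recent MAX_REQUEST_IDS (3) ids are kept;
+                    # the bounded deque drops older ids automatically.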
+ kwargs["previous_request_ids"] = list(previous_request_ids) + + # Synthesize audio while text chunks are still being accumulated + _LOGGER.debug("Synthesizing TTS for text: %s", text) + try: + async with self._client.text_to_speech.with_raw_response.stream( + **kwargs + ) as stream: + async for chunk_bytes in stream.data: + yield chunk_bytes + + if use_request_ids: + if (rid := stream.headers.get("request-id")) is not None: + previous_request_ids.append(rid) + else: + _LOGGER.debug( + "No request-id returned from server; clearing previous requests" + ) + previous_request_ids.clear() + except ApiError as exc: + _LOGGER.warning( + "Error during processing of TTS request %s", exc, exc_info=True + ) + _add_sentences_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await _add_sentences_task + raise HomeAssistantError(exc) from exc + + # Capture and store server request-id for next calls (only when supported) + _LOGGER.debug("Completed TTS stream for text: %s", text) + + _LOGGER.debug("Completed TTS stream") diff --git a/mypy.ini b/mypy.ini index f7a36041fa9..3f987800262 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1576,16 +1576,6 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true -[mypy-homeassistant.components.elevenlabs.*] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -warn_return_any = true -warn_unreachable = true - [mypy-homeassistant.components.elgato.*] check_untyped_defs = true disallow_incomplete_defs = true diff --git a/requirements_all.txt b/requirements_all.txt index fcd51da3f48..57fc8370468 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2812,6 +2812,9 @@ sensorpush-ha==1.3.2 # homeassistant.components.sensoterra sensoterra==2.0.1 +# homeassistant.components.elevenlabs +sentence-stream==1.2.0 + # homeassistant.components.sentry sentry-sdk==1.45.1 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 605a2745bd7..fde780d62db 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2337,6 +2337,9 @@ sensorpush-ha==1.3.2 # homeassistant.components.sensoterra sensoterra==2.0.1 +# homeassistant.components.elevenlabs +sentence-stream==1.2.0 + # homeassistant.components.sentry sentry-sdk==1.45.1 diff --git a/tests/components/elevenlabs/test_tts.py b/tests/components/elevenlabs/test_tts.py index c5e7529e5a0..08ab1b6fab0 100644 --- a/tests/components/elevenlabs/test_tts.py +++ b/tests/components/elevenlabs/test_tts.py @@ -2,9 +2,12 @@ from __future__ import annotations +import asyncio +from collections import deque +from collections.abc import AsyncIterator, Iterator from http import HTTPStatus from pathlib import Path -from typing import Any +from typing import Any, Self from unittest.mock import AsyncMock, MagicMock from elevenlabs.core import ApiError @@ -28,6 +31,7 @@ from homeassistant.components.media_player import ( DOMAIN as DOMAIN_MP, SERVICE_PLAY_MEDIA, ) +from homeassistant.components.tts import TTSAudioRequest from homeassistant.const import ATTR_ENTITY_ID from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core_config import async_process_ha_core_config @@ -37,17 +41,99 @@ from tests.components.tts.common import retrieve_media from tests.typing import ClientSessionGenerator -class FakeAudioGenerator: - """Mock audio generator for ElevenLabs TTS.""" +class _FakeResponse: + def __init__(self, headers: dict[str, str]) 
+
+
+@pytest.fixture
+def capture_stream_calls():
+    """Patch text_to_speech.with_raw_response.stream and capture call kwargs.
+
+    Returns a (calls, set_next_return, patch_stream) tuple:
+      calls: list of kwargs dicts, one per stream() invocation
+      set_next_return(chunks=..., request_id=..., error=...): sets what the
+        next stream() call yields, reports as request-id, or raises
+      patch_stream(tts_entity): installs the mock on the TTS entity's client
+    """
+    calls: list[dict] = []
+    state: dict[str, Any] = {"chunks": [b"X"], "request_id": "rid-1", "error": None}
+
+    def set_next_return(
+        *, chunks: list[bytes], request_id: str | None, error: Exception | None = None
+    ) -> None:
+        state["chunks"] = chunks
+        state["request_id"] = request_id
+        state["error"] = error
+
+    def patch_stream(tts_entity) -> None:
+        def _mock_stream(**kwargs):
+            calls.append(kwargs)
+            if state["error"] is not None:
+                raise state["error"]
+            return _AsyncStreamResponse(
+                chunks=list(state["chunks"]),
+                request_id=state["request_id"],
+            )
+
+        tts_entity._client.text_to_speech.with_raw_response.stream = _mock_stream
+
+    return calls, set_next_return, patch_stream
+
+
+@pytest.fixture
+def stream_sentence_helpers():
+    """Return helpers for queue-driven sentence streaming."""
+
+    def factory(sentence_iter: Iterator[tuple], queue: asyncio.Queue[str | None]):
+        async def get_next_part() -> tuple[Any, ...]:
+            try:
+                return next(sentence_iter)
+            except StopIteration:
+                await queue.put(None)
+                return None, None, None
+
+        async def message_gen() -> AsyncIterator[str]:
+            while True:
+                part = await queue.get()
+                if part is None:
+                    break
+                yield part
+
+        return get_next_part, message_gen
+
+    return factory
 
 
 @pytest.fixture(autouse=True)
@@ -134,15 +220,15 @@ async def test_tts_service_speak(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
     calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
 ) -> None:
     """Test tts service."""
+    stream_calls, _, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
+    patch_stream(tts_entity)
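+    # patch_stream swaps in the fake stream() defined above; by default it
+    # yields a single b"X" chunk and reports request id "rid-1".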
 
     assert tts_entity._voice_settings == VoiceSettings(
         stability=DEFAULT_STABILITY,
@@ -158,20 +244,22 @@
         blocking=True,
     )
 
-    assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.OK
     )
+    assert len(stream_calls) == 1
 
     voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "voice1")
     model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, "model1")
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
 
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id=voice_id,
-        model_id=model_id,
-        voice_settings=tts_entity._voice_settings,
-    )
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+    assert call_kwargs["voice_id"] == voice_id
+    assert call_kwargs["model_id"] == model_id
+    assert call_kwargs["voice_settings"] == tts_entity._voice_settings
+    assert call_kwargs["output_format"] == "mp3_44100_128"
+    assert call_kwargs["language_code"] == language
 
 
 @pytest.mark.parametrize(
@@ -206,15 +294,15 @@ async def test_tts_service_speak_lang_config(
     setup: AsyncMock,
     hass: HomeAssistant,
     hass_client: ClientSessionGenerator,
+    capture_stream_calls,
     calls: list[ServiceCall],
     tts_service: str,
     service_data: dict[str, Any],
 ) -> None:
     """Test service call say with other langcodes in the config."""
+    stream_calls, _, patch_stream = capture_stream_calls
     tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
-    tts_entity._client.text_to_speech.convert = MagicMock(
-        return_value=FakeAudioGenerator()
-    )
+    patch_stream(tts_entity)
 
     await hass.services.async_call(
         tts.DOMAIN,
@@ -223,18 +311,20 @@
         blocking=True,
     )
 
-    assert len(calls) == 1
     assert (
         await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
         == HTTPStatus.OK
     )
 
-    tts_entity._client.text_to_speech.convert.assert_called_once_with(
-        text="There is a person at the front door.",
-        voice_id="voice1",
-        model_id="model1",
-        voice_settings=tts_entity._voice_settings,
-    )
+    assert len(stream_calls) == 1
+    language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language)
+    call_kwargs = stream_calls[0]
+    assert call_kwargs["text"] == "There is a person at the front door."
+ assert call_kwargs["voice_id"] == "voice1" + assert call_kwargs["model_id"] == "model1" + assert call_kwargs["voice_settings"] == tts_entity._voice_settings + assert call_kwargs["output_format"] == "mp3_44100_128" + assert call_kwargs["language_code"] == language @pytest.mark.parametrize( @@ -257,16 +347,16 @@ async def test_tts_service_speak_error( setup: AsyncMock, hass: HomeAssistant, hass_client: ClientSessionGenerator, + capture_stream_calls, calls: list[ServiceCall], tts_service: str, service_data: dict[str, Any], ) -> None: """Test service call say with http response 400.""" + stream_calls, set_next_return, patch_stream = capture_stream_calls tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID]) - tts_entity._client.text_to_speech.convert = MagicMock( - return_value=FakeAudioGenerator() - ) - tts_entity._client.text_to_speech.convert.side_effect = ApiError + patch_stream(tts_entity) + set_next_return(chunks=[], request_id=None, error=ApiError()) await hass.services.async_call( tts.DOMAIN, @@ -275,18 +365,20 @@ async def test_tts_service_speak_error( blocking=True, ) - assert len(calls) == 1 assert ( await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == HTTPStatus.INTERNAL_SERVER_ERROR ) - tts_entity._client.text_to_speech.convert.assert_called_once_with( - text="There is a person at the front door.", - voice_id="voice1", - model_id="model1", - voice_settings=tts_entity._voice_settings, - ) + assert len(stream_calls) == 1 + language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language) + call_kwargs = stream_calls[0] + assert call_kwargs["text"] == "There is a person at the front door." + assert call_kwargs["voice_id"] == "voice1" + assert call_kwargs["model_id"] == "model1" + assert call_kwargs["voice_settings"] == tts_entity._voice_settings + assert call_kwargs["output_format"] == "mp3_44100_128" + assert call_kwargs["language_code"] == language @pytest.mark.parametrize( @@ -323,16 +415,17 @@ async def test_tts_service_speak_voice_settings( setup: AsyncMock, hass: HomeAssistant, hass_client: ClientSessionGenerator, + capture_stream_calls, calls: list[ServiceCall], tts_service: str, service_data: dict[str, Any], mock_similarity: float, ) -> None: """Test tts service.""" + stream_calls, _, patch_stream = capture_stream_calls tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID]) - tts_entity._client.text_to_speech.convert = MagicMock( - return_value=FakeAudioGenerator() - ) + patch_stream(tts_entity) + assert tts_entity._voice_settings == VoiceSettings( stability=DEFAULT_STABILITY, similarity_boost=DEFAULT_SIMILARITY / 2, @@ -347,18 +440,20 @@ async def test_tts_service_speak_voice_settings( blocking=True, ) - assert len(calls) == 1 assert ( await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == HTTPStatus.OK ) - tts_entity._client.text_to_speech.convert.assert_called_once_with( - text="There is a person at the front door.", - voice_id="voice2", - model_id="model1", - voice_settings=tts_entity._voice_settings, - ) + assert len(stream_calls) == 1 + language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language) + call_kwargs = stream_calls[0] + assert call_kwargs["text"] == "There is a person at the front door." 
+ assert call_kwargs["voice_id"] == "voice2" + assert call_kwargs["model_id"] == "model1" + assert call_kwargs["voice_settings"] == tts_entity._voice_settings + assert call_kwargs["output_format"] == "mp3_44100_128" + assert call_kwargs["language_code"] == language @pytest.mark.parametrize( @@ -381,15 +476,15 @@ async def test_tts_service_speak_without_options( setup: AsyncMock, hass: HomeAssistant, hass_client: ClientSessionGenerator, + capture_stream_calls, calls: list[ServiceCall], tts_service: str, service_data: dict[str, Any], ) -> None: """Test service call say with http response 200.""" + stream_calls, _, patch_stream = capture_stream_calls tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID]) - tts_entity._client.text_to_speech.convert = MagicMock( - return_value=FakeAudioGenerator() - ) + patch_stream(tts_entity) await hass.services.async_call( tts.DOMAIN, @@ -398,17 +493,179 @@ async def test_tts_service_speak_without_options( blocking=True, ) - assert len(calls) == 1 assert ( await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == HTTPStatus.OK ) - tts_entity._client.text_to_speech.convert.assert_called_once_with( - text="There is a person at the front door.", - voice_id="voice1", - voice_settings=VoiceSettings( - stability=0.5, similarity_boost=0.75, style=0.0, use_speaker_boost=True - ), - model_id="model1", + assert len(stream_calls) == 1 + language = service_data.get(tts.ATTR_LANGUAGE, tts_entity.default_language) + call_kwargs = stream_calls[0] + assert call_kwargs["text"] == "There is a person at the front door." + assert call_kwargs["voice_id"] == "voice1" + assert call_kwargs["model_id"] == "model1" + assert call_kwargs["voice_settings"] == VoiceSettings( + stability=0.5, + similarity_boost=0.75, + style=0.0, + use_speaker_boost=True, ) + assert call_kwargs["output_format"] == "mp3_44100_128" + assert call_kwargs["language_code"] == language + + +@pytest.mark.parametrize( + ("setup", "model_id"), + [ + ("mock_config_entry_setup", "eleven_multilingual_v2"), + ], + indirect=["setup"], +) +@pytest.mark.parametrize( + ("message", "chunks", "request_ids"), + [ + ( + [ + ["One. ", "Two! ", "Three"], + ["! ", "Four"], + ["? ", "Five"], + ["! 
", "Six!"], + ], + [b"\x05\x06", b"\x07\x08", b"\x09\x0a", b"\x0b\x0c"], + ["rid-1", "rid-2", "rid-3", "rid-4"], + ), + ], +) +async def test_stream_tts_with_request_ids( + setup: AsyncMock, + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + capture_stream_calls, + stream_sentence_helpers, + model_id: str, + message: list[list[str]], + chunks: list[bytes], + request_ids: list[str], +) -> None: + """Test streaming TTS with request-id stitching.""" + calls, set_next_return, patch_stream = capture_stream_calls + + # Access the TTS entity as in your existing tests; adjust if you use a fixture instead + tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech") + patch_stream(tts_entity) + + # Use a queue to control when each part is yielded + queue = asyncio.Queue() + prev_request_ids: deque[str] = deque(maxlen=3) # keep last 3 request IDs + sentence_iter = iter(zip(message, chunks, request_ids, strict=False)) + get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue) + options = {tts.ATTR_VOICE: "voice1", "model": model_id} + req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options) + + resp = await tts_entity.async_stream_tts_audio(req) + assert resp.extension == "mp3" + + item, chunk, request_id = await get_next_part() + if item is not None: + for part in item: + await queue.put(part) + else: + await queue.put(None) + + set_next_return(chunks=[chunk], request_id=request_id) + next_item, next_chunk, next_request_id = await get_next_part() + # Consume bytes; after first chunk, switch next return to emulate second call + async for b in resp.data_gen: + assert b == chunk # each sentence yields its first chunk immediately + assert "previous_text" not in calls[-1] # no previous_text for first sentence + assert "next_text" not in calls[-1] # no next_text for first + assert calls[-1].get("previous_request_ids", []) == ( + [] if len(calls) == 1 else list(prev_request_ids) + ) + if request_id: + prev_request_ids.append(request_id or "") + item, chunk, request_id = next_item, next_chunk, next_request_id + if item is not None: + for part in item: + await queue.put(part) + set_next_return(chunks=[chunk], request_id=request_id) + next_item, next_chunk, next_request_id = await get_next_part() + if item is None: + await queue.put(None) + else: + await queue.put(None) + + # We expect two stream() invocations (one per sentence batch) + assert len(calls) == len(message) + + +@pytest.mark.parametrize( + ("message", "chunks", "request_ids"), + [ + ( + [ + ["This is the first sentence. ", "This is "], + ["the second sentence. 
"], + ], + [b"\x05\x06", b"\x07\x08"], + ["rid-1", "rid-2"], + ), + ], +) +async def test_stream_tts_without_previous_info( + setup: AsyncMock, + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + capture_stream_calls, + stream_sentence_helpers, + monkeypatch: pytest.MonkeyPatch, + message: list[list[str]], + chunks: list[bytes], + request_ids: list[str], +) -> None: + """Test streaming TTS without request-id stitching (eleven_v3).""" + calls, set_next_return, patch_stream = capture_stream_calls + tts_entity = hass.data[tts.DOMAIN].get_entity("tts.elevenlabs_text_to_speech") + patch_stream(tts_entity) + monkeypatch.setattr( + "homeassistant.components.elevenlabs.tts.MODELS_PREVIOUS_INFO_NOT_SUPPORTED", + ("model1",), + raising=False, + ) + + queue = asyncio.Queue() + sentence_iter = iter(zip(message, chunks, request_ids, strict=False)) + get_next_part, message_gen = stream_sentence_helpers(sentence_iter, queue) + options = {tts.ATTR_VOICE: "voice1", "model": "model1"} + req = TTSAudioRequest(message_gen=message_gen(), language="en", options=options) + + resp = await tts_entity.async_stream_tts_audio(req) + assert resp.extension == "mp3" + + item, chunk, request_id = await get_next_part() + if item is not None: + for part in item: + await queue.put(part) + else: + await queue.put(None) + + set_next_return(chunks=[chunk], request_id=request_id) + next_item, next_chunk, next_request_id = await get_next_part() + # Consume bytes; after first chunk, switch next return to emulate second call + async for b in resp.data_gen: + assert b == chunk # each sentence yields its first chunk immediately + assert "previous_request_ids" not in calls[-1] # no previous_request_ids + + item, chunk, request_id = next_item, next_chunk, next_request_id + if item is not None: + for part in item: + await queue.put(part) + set_next_return(chunks=[chunk], request_id=request_id) + next_item, next_chunk, next_request_id = await get_next_part() + if item is None: + await queue.put(None) + else: + await queue.put(None) + + # We expect two stream() invocations (one per sentence batch) + assert len(calls) == len(message)