diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 3645afedd6d..fb9dfcac13c 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -527,6 +527,8 @@ class ResultStream: This method will leverage a disk cache to speed up generation. """ + if self._result_cache.done(): + return self._result_cache.set_result( self._manager.async_cache_message_in_memory( engine=self.engine, @@ -543,6 +545,8 @@ class ResultStream: This method can result in faster first byte when generating long responses. """ + if self._result_cache.done(): + return self._result_cache.set_result( self._manager.async_cache_message_stream_in_memory( engine=self.engine, diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 5a6e988c82d..ee7878e603a 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -1954,6 +1954,46 @@ async def test_stream(hass: HomeAssistant, mock_tts_entity: MockTTSEntity) -> No assert result_data == data +async def test_result_stream_message_set_idempotent( + hass: HomeAssistant, mock_tts_entity: MockTTSEntity +) -> None: + """Test setting a result stream message more than once.""" + await mock_config_entry_setup(hass, mock_tts_entity) + + stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) + stream.async_set_message("hello") + cache_first = stream._result_cache.result() + stream.async_set_message("world") + assert stream._result_cache.result() is cache_first + + async def async_stream_tts_audio( + request: tts.TTSAudioRequest, + ) -> tts.TTSAudioResponse: + """Mock stream TTS audio.""" + + async def gen_data(): + async for msg in request.message_gen: + yield msg.encode() + + return tts.TTSAudioResponse( + extension="mp3", + data_gen=gen_data(), + ) + + mock_tts_entity.async_stream_tts_audio = async_stream_tts_audio + mock_tts_entity.async_supports_streaming_input = Mock(return_value=True) + + async def stream_message(): + """Mock stream message.""" + yield "h" + + stream2 = tts.async_create_stream(hass, mock_tts_entity.entity_id) + stream2.async_set_message_stream(stream_message()) + cache_first = stream2._result_cache.result() + stream2.async_set_message_stream(stream_message()) + assert stream2._result_cache.result() is cache_first + + async def test_tts_cache() -> None: """Test TTSCache."""