mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 21:06:19 +00:00
Always chunk Wyoming TTS audio (#156079)
This commit is contained in:
@@ -54,7 +54,7 @@ _PING_TIMEOUT: Final = 5
|
||||
_PING_SEND_DELAY: Final = 2
|
||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||
_TTS_SAMPLE_RATE: Final = 22050
|
||||
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
|
||||
_AUDIO_CHUNK_BYTES: Final = 2048 # 1024 samples
|
||||
_TTS_TIMEOUT_EXTRA: Final = 1.0
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
@@ -360,7 +360,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
)
|
||||
assert proc.stdout is not None
|
||||
while True:
|
||||
chunk_bytes = await proc.stdout.read(_ANNOUNCE_CHUNK_BYTES)
|
||||
chunk_bytes = await proc.stdout.read(_AUDIO_CHUNK_BYTES)
|
||||
if not chunk_bytes:
|
||||
break
|
||||
|
||||
@@ -782,17 +782,22 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
assert sample_width is not None
|
||||
assert sample_channels is not None
|
||||
|
||||
audio_chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=data_chunk,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
data_chunk_idx = 0
|
||||
while data_chunk_idx < len(data_chunk):
|
||||
audio_chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=data_chunk[
|
||||
data_chunk_idx : data_chunk_idx + _AUDIO_CHUNK_BYTES
|
||||
],
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
await self._client.write_event(audio_chunk.event())
|
||||
timestamp += audio_chunk.milliseconds
|
||||
total_seconds += audio_chunk.seconds
|
||||
await self._client.write_event(audio_chunk.event())
|
||||
timestamp += audio_chunk.milliseconds
|
||||
total_seconds += audio_chunk.seconds
|
||||
data_chunk_idx += _AUDIO_CHUNK_BYTES
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
|
||||
@@ -59,7 +59,7 @@ async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
return entry
|
||||
|
||||
|
||||
def get_test_wav() -> bytes:
|
||||
def get_test_wav(chunk_copies: int = 1) -> bytes:
|
||||
"""Get bytes for test WAV file."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
@@ -68,7 +68,7 @@ def get_test_wav() -> bytes:
|
||||
wav_file.setnchannels(1)
|
||||
|
||||
# Single frame
|
||||
wav_file.writeframes(b"1234")
|
||||
wav_file.writeframes(b"1234" * chunk_copies)
|
||||
|
||||
return wav_io.getvalue()
|
||||
|
||||
@@ -111,6 +111,7 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||
self.tts_audio_chunk_event = asyncio.Event()
|
||||
self.tts_audio_stop_event = asyncio.Event()
|
||||
self.tts_audio_chunk: AudioChunk | None = None
|
||||
self.tts_audio_chunks: list[AudioChunk] = []
|
||||
|
||||
self.error_event = asyncio.Event()
|
||||
self.error: Error | None = None
|
||||
@@ -169,6 +170,7 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||
self.tts_audio_start_event.set()
|
||||
elif AudioChunk.is_type(event.type):
|
||||
self.tts_audio_chunk = AudioChunk.from_event(event)
|
||||
self.tts_audio_chunks.append(self.tts_audio_chunk)
|
||||
self.tts_audio_chunk_event.set()
|
||||
elif AudioStop.is_type(event.type):
|
||||
self.tts_audio_stop_event.set()
|
||||
@@ -1537,7 +1539,7 @@ async def test_satellite_tts_streaming(hass: HomeAssistant) -> None:
|
||||
assert pipeline_kwargs.get("device_id") == device.device_id
|
||||
|
||||
# Send TTS info early
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav(1000))
|
||||
pipeline_event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.RUN_START,
|
||||
@@ -1604,12 +1606,14 @@ async def test_satellite_tts_streaming(hass: HomeAssistant) -> None:
|
||||
await mock_client.tts_audio_chunk_event.wait()
|
||||
await mock_client.tts_audio_stop_event.wait()
|
||||
|
||||
# Verify audio chunk from test WAV
|
||||
assert mock_client.tts_audio_chunk is not None
|
||||
assert mock_client.tts_audio_chunk.rate == 22050
|
||||
assert mock_client.tts_audio_chunk.width == 2
|
||||
assert mock_client.tts_audio_chunk.channels == 1
|
||||
assert mock_client.tts_audio_chunk.audio == b"1234"
|
||||
# Verify audio chunks from test WAV
|
||||
assert len(mock_client.tts_audio_chunks) == 2
|
||||
chunk_sizes = (2048, 1952) # 1024 samples per chunk
|
||||
for i, audio_chunk in enumerate(mock_client.tts_audio_chunks):
|
||||
assert audio_chunk.rate == 22050
|
||||
assert audio_chunk.width == 2
|
||||
assert audio_chunk.channels == 1
|
||||
assert len(audio_chunk.audio) == chunk_sizes[i]
|
||||
|
||||
# Text-to-speech text
|
||||
pipeline_event_callback(
|
||||
|
||||
Reference in New Issue
Block a user