diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index ea344df038d..739d75a08ad 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -24,7 +24,7 @@ from .const import ( SAMPLE_WIDTH, SAMPLES_PER_CHUNK, ) -from .error import PipelineError, PipelineNotFound +from .error import PipelineNotFound from .pipeline import ( AudioSettings, Pipeline, @@ -137,21 +137,4 @@ async def async_pipeline_from_audio_stream( audio_settings=audio_settings or AudioSettings(), ), ) - try: - await pipeline_input.validate() - except PipelineError as err: - pipeline_input.run.start( - conversation_id=session.conversation_id, - device_id=device_id, - satellite_id=satellite_id, - ) - pipeline_input.run.process_event( - PipelineEvent( - PipelineEventType.ERROR, - {"code": err.code, "message": err.message}, - ) - ) - await pipeline_input.run.end() - return - - await pipeline_input.execute() + await pipeline_input.execute(validate=True) diff --git a/homeassistant/components/assist_pipeline/error.py b/homeassistant/components/assist_pipeline/error.py index 8b72331817c..d12f41ce144 100644 --- a/homeassistant/components/assist_pipeline/error.py +++ b/homeassistant/components/assist_pipeline/error.py @@ -1,7 +1,14 @@ """Assist pipeline errors.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + from homeassistant.exceptions import HomeAssistantError +if TYPE_CHECKING: + from .pipeline import PipelineStage + class PipelineError(HomeAssistantError): """Base class for pipeline errors.""" @@ -55,3 +62,25 @@ class IntentRecognitionError(PipelineError): class TextToSpeechError(PipelineError): """Error in text-to-speech portion of pipeline.""" + + +class PipelineRunValidationError(PipelineError): + """Error when a pipeline run is not valid.""" + + def __init__(self, message: str) -> None: + """Set error message.""" + super().__init__("validation-error", message) + + +class InvalidPipelineStagesError(PipelineRunValidationError): + """Error when given an invalid combination of start/end stages.""" + + def __init__( + self, + start_stage: PipelineStage, + end_stage: PipelineStage, + ) -> None: + """Set error message.""" + super().__init__( + f"Invalid stage combination: start={start_stage}, end={end_stage}" + ) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index a6d402dd8e3..67aa14f04d7 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -73,8 +73,10 @@ from .const import ( from .error import ( DuplicateWakeUpDetectedError, IntentRecognitionError, + InvalidPipelineStagesError, PipelineError, PipelineNotFound, + PipelineRunValidationError, SpeechToTextError, TextToSpeechError, WakeWordDetectionAborted, @@ -492,24 +494,6 @@ PIPELINE_STAGE_ORDER = [ ] -class PipelineRunValidationError(Exception): - """Error when a pipeline run is not valid.""" - - -class InvalidPipelineStagesError(PipelineRunValidationError): - """Error when given an invalid combination of start/end stages.""" - - def __init__( - self, - start_stage: PipelineStage, - end_stage: PipelineStage, - ) -> None: - """Set error message.""" - super().__init__( - f"Invalid stage combination: start={start_stage}, end={end_stage}" - ) - - @dataclass(frozen=True) class WakeWordSettings: """Settings for wake word detection.""" @@ -1680,26 +1664,39 @@ class PipelineInput: satellite_id: str | None = None """Identifier of the satellite that is processing the input/output of the pipeline.""" - async def execute(self) -> None: + async def execute(self, validate: bool = False) -> None: """Run pipeline.""" + validation_error: PipelineError | None = None + if validate: + try: + await self.validate() + except PipelineError as err: + validation_error = err + self.run.start( conversation_id=self.session.conversation_id, device_id=self.device_id, satellite_id=self.satellite_id, ) current_stage: PipelineStage | None = self.run.start_stage - stt_audio_buffer: list[EnhancedAudioChunk] = [] - stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None - - if self.stt_stream is not None: - if self.run.audio_settings.needs_processor: - # VAD/noise suppression/auto gain/volume - stt_processed_stream = self.run.process_enhance_audio(self.stt_stream) - else: - # Volume multiplier only - stt_processed_stream = self.run.process_volume_only(self.stt_stream) try: + if validation_error is not None: + raise validation_error + + stt_audio_buffer: list[EnhancedAudioChunk] = [] + stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None + + if self.stt_stream is not None: + if self.run.audio_settings.needs_processor: + # VAD/noise suppression/auto gain/volume + stt_processed_stream = self.run.process_enhance_audio( + self.stt_stream + ) + else: + # Volume multiplier only + stt_processed_stream = self.run.process_volume_only(self.stt_stream) + if current_stage == PipelineStage.WAKE_WORD: # wake-word-detection assert stt_processed_stream is not None diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 3c2cdbfb0f8..bfa8ab0e452 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -270,25 +270,28 @@ async def test_pipeline_from_audio_stream_no_stt( pipeline_id = msg["result"]["id"] # Try to use the created pipeline - with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError): - await assist_pipeline.async_pipeline_from_audio_stream( - hass, - context=Context(), - event_callback=events.append, - stt_metadata=stt.SpeechMetadata( - language="en-UK", - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=audio_data(), - pipeline_id=pipeline_id, - audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), - ) + await assist_pipeline.async_pipeline_from_audio_stream( + hass, + context=Context(), + event_callback=events.append, + stt_metadata=stt.SpeechMetadata( + language="en-UK", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=audio_data(), + pipeline_id=pipeline_id, + audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), + ) - assert not events + assert len(events) == 3 + assert events[0].type == assist_pipeline.PipelineEventType.RUN_START + assert events[1].type == assist_pipeline.PipelineEventType.ERROR + assert events[1].data["code"] == "validation-error" + assert events[2].type == assist_pipeline.PipelineEventType.RUN_END async def test_pipeline_from_audio_stream_unknown_pipeline(