1
0
mirror of https://github.com/home-assistant/core.git synced 2026-04-02 08:26:41 +01:00

Move pipeline input validation into execute method (#166373)

This commit is contained in:
Artur Pragacz
2026-03-24 19:32:13 +01:00
committed by GitHub
parent 001a1aada6
commit b42bd4909b
4 changed files with 78 additions and 66 deletions

View File

@@ -24,7 +24,7 @@ from .const import (
SAMPLE_WIDTH,
SAMPLES_PER_CHUNK,
)
from .error import PipelineError, PipelineNotFound
from .error import PipelineNotFound
from .pipeline import (
AudioSettings,
Pipeline,
@@ -137,21 +137,4 @@ async def async_pipeline_from_audio_stream(
audio_settings=audio_settings or AudioSettings(),
),
)
try:
await pipeline_input.validate()
except PipelineError as err:
pipeline_input.run.start(
conversation_id=session.conversation_id,
device_id=device_id,
satellite_id=satellite_id,
)
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": err.code, "message": err.message},
)
)
await pipeline_input.run.end()
return
await pipeline_input.execute()
await pipeline_input.execute(validate=True)

View File

@@ -1,7 +1,14 @@
"""Assist pipeline errors."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.exceptions import HomeAssistantError
if TYPE_CHECKING:
from .pipeline import PipelineStage
class PipelineError(HomeAssistantError):
"""Base class for pipeline errors."""
@@ -55,3 +62,25 @@ class IntentRecognitionError(PipelineError):
class TextToSpeechError(PipelineError):
"""Error in text-to-speech portion of pipeline."""
class PipelineRunValidationError(PipelineError):
"""Error when a pipeline run is not valid."""
def __init__(self, message: str) -> None:
"""Set error message."""
super().__init__("validation-error", message)
class InvalidPipelineStagesError(PipelineRunValidationError):
"""Error when given an invalid combination of start/end stages."""
def __init__(
self,
start_stage: PipelineStage,
end_stage: PipelineStage,
) -> None:
"""Set error message."""
super().__init__(
f"Invalid stage combination: start={start_stage}, end={end_stage}"
)

View File

@@ -73,8 +73,10 @@ from .const import (
from .error import (
DuplicateWakeUpDetectedError,
IntentRecognitionError,
InvalidPipelineStagesError,
PipelineError,
PipelineNotFound,
PipelineRunValidationError,
SpeechToTextError,
TextToSpeechError,
WakeWordDetectionAborted,
@@ -492,24 +494,6 @@ PIPELINE_STAGE_ORDER = [
]
class PipelineRunValidationError(Exception):
"""Error when a pipeline run is not valid."""
class InvalidPipelineStagesError(PipelineRunValidationError):
"""Error when given an invalid combination of start/end stages."""
def __init__(
self,
start_stage: PipelineStage,
end_stage: PipelineStage,
) -> None:
"""Set error message."""
super().__init__(
f"Invalid stage combination: start={start_stage}, end={end_stage}"
)
@dataclass(frozen=True)
class WakeWordSettings:
"""Settings for wake word detection."""
@@ -1680,26 +1664,39 @@ class PipelineInput:
satellite_id: str | None = None
"""Identifier of the satellite that is processing the input/output of the pipeline."""
async def execute(self) -> None:
async def execute(self, validate: bool = False) -> None:
"""Run pipeline."""
validation_error: PipelineError | None = None
if validate:
try:
await self.validate()
except PipelineError as err:
validation_error = err
self.run.start(
conversation_id=self.session.conversation_id,
device_id=self.device_id,
satellite_id=self.satellite_id,
)
current_stage: PipelineStage | None = self.run.start_stage
stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
if self.stt_stream is not None:
if self.run.audio_settings.needs_processor:
# VAD/noise suppression/auto gain/volume
stt_processed_stream = self.run.process_enhance_audio(self.stt_stream)
else:
# Volume multiplier only
stt_processed_stream = self.run.process_volume_only(self.stt_stream)
try:
if validation_error is not None:
raise validation_error
stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
if self.stt_stream is not None:
if self.run.audio_settings.needs_processor:
# VAD/noise suppression/auto gain/volume
stt_processed_stream = self.run.process_enhance_audio(
self.stt_stream
)
else:
# Volume multiplier only
stt_processed_stream = self.run.process_volume_only(self.stt_stream)
if current_stage == PipelineStage.WAKE_WORD:
# wake-word-detection
assert stt_processed_stream is not None

View File

@@ -270,25 +270,28 @@ async def test_pipeline_from_audio_stream_no_stt(
pipeline_id = msg["result"]["id"]
# Try to use the created pipeline
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert not events
assert len(events) == 3
assert events[0].type == assist_pipeline.PipelineEventType.RUN_START
assert events[1].type == assist_pipeline.PipelineEventType.ERROR
assert events[1].data["code"] == "validation-error"
assert events[2].type == assist_pipeline.PipelineEventType.RUN_END
async def test_pipeline_from_audio_stream_unknown_pipeline(