mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 08:26:41 +01:00
Move pipeline input validation into execute method (#166373)
This commit is contained in:
@@ -24,7 +24,7 @@ from .const import (
|
||||
SAMPLE_WIDTH,
|
||||
SAMPLES_PER_CHUNK,
|
||||
)
|
||||
from .error import PipelineError, PipelineNotFound
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
AudioSettings,
|
||||
Pipeline,
|
||||
@@ -137,21 +137,4 @@ async def async_pipeline_from_audio_stream(
|
||||
audio_settings=audio_settings or AudioSettings(),
|
||||
),
|
||||
)
|
||||
try:
|
||||
await pipeline_input.validate()
|
||||
except PipelineError as err:
|
||||
pipeline_input.run.start(
|
||||
conversation_id=session.conversation_id,
|
||||
device_id=device_id,
|
||||
satellite_id=satellite_id,
|
||||
)
|
||||
pipeline_input.run.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": err.code, "message": err.message},
|
||||
)
|
||||
)
|
||||
await pipeline_input.run.end()
|
||||
return
|
||||
|
||||
await pipeline_input.execute()
|
||||
await pipeline_input.execute(validate=True)
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
"""Assist pipeline errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pipeline import PipelineStage
|
||||
|
||||
|
||||
class PipelineError(HomeAssistantError):
|
||||
"""Base class for pipeline errors."""
|
||||
@@ -55,3 +62,25 @@ class IntentRecognitionError(PipelineError):
|
||||
|
||||
class TextToSpeechError(PipelineError):
|
||||
"""Error in text-to-speech portion of pipeline."""
|
||||
|
||||
|
||||
class PipelineRunValidationError(PipelineError):
|
||||
"""Error when a pipeline run is not valid."""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
"""Set error message."""
|
||||
super().__init__("validation-error", message)
|
||||
|
||||
|
||||
class InvalidPipelineStagesError(PipelineRunValidationError):
|
||||
"""Error when given an invalid combination of start/end stages."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_stage: PipelineStage,
|
||||
end_stage: PipelineStage,
|
||||
) -> None:
|
||||
"""Set error message."""
|
||||
super().__init__(
|
||||
f"Invalid stage combination: start={start_stage}, end={end_stage}"
|
||||
)
|
||||
|
||||
@@ -73,8 +73,10 @@ from .const import (
|
||||
from .error import (
|
||||
DuplicateWakeUpDetectedError,
|
||||
IntentRecognitionError,
|
||||
InvalidPipelineStagesError,
|
||||
PipelineError,
|
||||
PipelineNotFound,
|
||||
PipelineRunValidationError,
|
||||
SpeechToTextError,
|
||||
TextToSpeechError,
|
||||
WakeWordDetectionAborted,
|
||||
@@ -492,24 +494,6 @@ PIPELINE_STAGE_ORDER = [
|
||||
]
|
||||
|
||||
|
||||
class PipelineRunValidationError(Exception):
|
||||
"""Error when a pipeline run is not valid."""
|
||||
|
||||
|
||||
class InvalidPipelineStagesError(PipelineRunValidationError):
|
||||
"""Error when given an invalid combination of start/end stages."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_stage: PipelineStage,
|
||||
end_stage: PipelineStage,
|
||||
) -> None:
|
||||
"""Set error message."""
|
||||
super().__init__(
|
||||
f"Invalid stage combination: start={start_stage}, end={end_stage}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WakeWordSettings:
|
||||
"""Settings for wake word detection."""
|
||||
@@ -1680,26 +1664,39 @@ class PipelineInput:
|
||||
satellite_id: str | None = None
|
||||
"""Identifier of the satellite that is processing the input/output of the pipeline."""
|
||||
|
||||
async def execute(self) -> None:
|
||||
async def execute(self, validate: bool = False) -> None:
|
||||
"""Run pipeline."""
|
||||
validation_error: PipelineError | None = None
|
||||
if validate:
|
||||
try:
|
||||
await self.validate()
|
||||
except PipelineError as err:
|
||||
validation_error = err
|
||||
|
||||
self.run.start(
|
||||
conversation_id=self.session.conversation_id,
|
||||
device_id=self.device_id,
|
||||
satellite_id=self.satellite_id,
|
||||
)
|
||||
current_stage: PipelineStage | None = self.run.start_stage
|
||||
stt_audio_buffer: list[EnhancedAudioChunk] = []
|
||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||
|
||||
if self.stt_stream is not None:
|
||||
if self.run.audio_settings.needs_processor:
|
||||
# VAD/noise suppression/auto gain/volume
|
||||
stt_processed_stream = self.run.process_enhance_audio(self.stt_stream)
|
||||
else:
|
||||
# Volume multiplier only
|
||||
stt_processed_stream = self.run.process_volume_only(self.stt_stream)
|
||||
|
||||
try:
|
||||
if validation_error is not None:
|
||||
raise validation_error
|
||||
|
||||
stt_audio_buffer: list[EnhancedAudioChunk] = []
|
||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||
|
||||
if self.stt_stream is not None:
|
||||
if self.run.audio_settings.needs_processor:
|
||||
# VAD/noise suppression/auto gain/volume
|
||||
stt_processed_stream = self.run.process_enhance_audio(
|
||||
self.stt_stream
|
||||
)
|
||||
else:
|
||||
# Volume multiplier only
|
||||
stt_processed_stream = self.run.process_volume_only(self.stt_stream)
|
||||
|
||||
if current_stage == PipelineStage.WAKE_WORD:
|
||||
# wake-word-detection
|
||||
assert stt_processed_stream is not None
|
||||
|
||||
@@ -270,25 +270,28 @@ async def test_pipeline_from_audio_stream_no_stt(
|
||||
pipeline_id = msg["result"]["id"]
|
||||
|
||||
# Try to use the created pipeline
|
||||
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id=pipeline_id,
|
||||
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||
)
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
context=Context(),
|
||||
event_callback=events.append,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
pipeline_id=pipeline_id,
|
||||
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||
)
|
||||
|
||||
assert not events
|
||||
assert len(events) == 3
|
||||
assert events[0].type == assist_pipeline.PipelineEventType.RUN_START
|
||||
assert events[1].type == assist_pipeline.PipelineEventType.ERROR
|
||||
assert events[1].data["code"] == "validation-error"
|
||||
assert events[2].type == assist_pipeline.PipelineEventType.RUN_END
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_unknown_pipeline(
|
||||
|
||||
Reference in New Issue
Block a user