mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 08:26:41 +01:00
Ensure STT metadata enums are passed (#165220)
This commit is contained in:
@@ -397,11 +397,11 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
|
||||
try:
|
||||
return SpeechMetadata(
|
||||
language=args["language"],
|
||||
format=args["format"],
|
||||
codec=args["codec"],
|
||||
bit_rate=args["bit_rate"],
|
||||
sample_rate=args["sample_rate"],
|
||||
channel=args["channel"],
|
||||
format=AudioFormats(args["format"]),
|
||||
codec=AudioCodecs(args["codec"]),
|
||||
bit_rate=AudioBitRates(int(args["bit_rate"])),
|
||||
sample_rate=AudioSampleRates(int(args["sample_rate"])),
|
||||
channel=AudioChannels(int(args["channel"])),
|
||||
)
|
||||
except ValueError as err:
|
||||
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err
|
||||
|
||||
@@ -23,12 +23,6 @@ class SpeechMetadata:
|
||||
sample_rate: AudioSampleRates
|
||||
channel: AudioChannels
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Finish initializing the metadata."""
|
||||
self.bit_rate = AudioBitRates(int(self.bit_rate))
|
||||
self.sample_rate = AudioSampleRates(int(self.sample_rate))
|
||||
self.channel = AudioChannels(int(self.channel))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechResult:
|
||||
|
||||
@@ -10,6 +10,11 @@ import pytest
|
||||
|
||||
from homeassistant.components.stt import (
|
||||
DOMAIN,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
AudioCodecs,
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
async_default_engine,
|
||||
async_get_provider,
|
||||
async_get_speech_to_text_engine,
|
||||
@@ -233,6 +238,42 @@ async def test_stream_audio(
|
||||
assert await response.json() == {"text": "test_result", "result": "success"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||
)
|
||||
async def test_stream_audio_uses_enum_values(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: MockSTTProvider | MockSTTProviderEntity,
|
||||
) -> None:
|
||||
"""Test that HTTP API passes enum values to async_process_audio_stream."""
|
||||
client = await hass_client()
|
||||
response = await client.post(
|
||||
f"/api/stt/{setup.url_path}",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||
" language=en"
|
||||
)
|
||||
},
|
||||
)
|
||||
assert response.status == HTTPStatus.OK
|
||||
|
||||
assert len(setup.calls) == 1
|
||||
metadata, _ = setup.calls[0]
|
||||
|
||||
assert isinstance(metadata.format, AudioFormats)
|
||||
assert metadata.format == AudioFormats.WAV
|
||||
assert isinstance(metadata.codec, AudioCodecs)
|
||||
assert metadata.codec == AudioCodecs.PCM
|
||||
assert isinstance(metadata.bit_rate, AudioBitRates)
|
||||
assert metadata.bit_rate == AudioBitRates.BITRATE_16
|
||||
assert isinstance(metadata.sample_rate, AudioSampleRates)
|
||||
assert metadata.sample_rate == AudioSampleRates.SAMPLERATE_16000
|
||||
assert isinstance(metadata.channel, AudioChannels)
|
||||
assert metadata.channel == AudioChannels.CHANNEL_MONO
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setup", ["mock_setup", "mock_config_entry_setup"], indirect=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user