mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Add TTS support for OpenAI (#162468)
Co-authored-by: Norbert Rittel <norbert@rittel.de> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
@@ -49,6 +49,7 @@ from .const import (
|
||||
CONF_TOP_P,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
@@ -57,13 +58,14 @@ from .const import (
|
||||
RECOMMENDED_REASONING_EFFORT,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from .entity import async_prepare_files_for_prompt
|
||||
|
||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||
SERVICE_GENERATE_CONTENT = "generate_content"
|
||||
|
||||
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
|
||||
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION, Platform.TTS)
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
||||
@@ -441,6 +443,10 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
|
||||
)
|
||||
hass.config_entries.async_update_entry(entry, minor_version=4)
|
||||
|
||||
if entry.version == 2 and entry.minor_version == 4:
|
||||
_add_tts_subentry(hass, entry)
|
||||
hass.config_entries.async_update_entry(entry, minor_version=5)
|
||||
|
||||
LOGGER.debug(
|
||||
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
||||
)
|
||||
@@ -459,3 +465,16 @@ def _add_ai_task_subentry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> None
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _add_tts_subentry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> None:
|
||||
"""Add TTS subentry to the config entry."""
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
ConfigSubentry(
|
||||
data=MappingProxyType(RECOMMENDED_TTS_OPTIONS),
|
||||
subentry_type="tts",
|
||||
title=DEFAULT_TTS_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -39,6 +39,9 @@ from homeassistant.helpers.selector import (
|
||||
SelectSelectorConfig,
|
||||
SelectSelectorMode,
|
||||
TemplateSelector,
|
||||
TextSelector,
|
||||
TextSelectorConfig,
|
||||
TextSelectorType,
|
||||
)
|
||||
from homeassistant.helpers.typing import VolDictType
|
||||
|
||||
@@ -53,6 +56,7 @@ from .const import (
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
CONF_TTS_SPEED,
|
||||
CONF_VERBOSITY,
|
||||
CONF_WEB_SEARCH,
|
||||
CONF_WEB_SEARCH_CITY,
|
||||
@@ -64,6 +68,7 @@ from .const import (
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
@@ -75,6 +80,8 @@ from .const import (
|
||||
RECOMMENDED_REASONING_SUMMARY,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
RECOMMENDED_TTS_SPEED,
|
||||
RECOMMENDED_VERBOSITY,
|
||||
RECOMMENDED_WEB_SEARCH,
|
||||
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
||||
@@ -110,7 +117,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for OpenAI Conversation."""
|
||||
|
||||
VERSION = 2
|
||||
MINOR_VERSION = 4
|
||||
MINOR_VERSION = 5
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
@@ -151,6 +158,12 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"subentry_type": "tts",
|
||||
"data": RECOMMENDED_TTS_OPTIONS,
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -191,13 +204,13 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
return {
|
||||
"conversation": OpenAISubentryFlowHandler,
|
||||
"ai_task_data": OpenAISubentryFlowHandler,
|
||||
"tts": OpenAISubentryTTSFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
class OpenAISubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing OpenAI subentries."""
|
||||
|
||||
last_rendered_recommended = False
|
||||
options: dict[str, Any]
|
||||
|
||||
@property
|
||||
@@ -580,3 +593,77 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
|
||||
_LOGGER.debug("Location data: %s", location_data)
|
||||
|
||||
return location_data
|
||||
|
||||
|
||||
class OpenAISubentryTTSFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing OpenAI TTS subentries."""
|
||||
|
||||
options: dict[str, Any]
|
||||
|
||||
@property
|
||||
def _is_new(self) -> bool:
|
||||
"""Return if this is a new subentry."""
|
||||
return self.source == "user"
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Add a subentry."""
|
||||
self.options = RECOMMENDED_TTS_OPTIONS.copy()
|
||||
return await self.async_step_init()
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Handle reconfiguration of a subentry."""
|
||||
self.options = self._get_reconfigure_subentry().data.copy()
|
||||
return await self.async_step_init()
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Manage initial options."""
|
||||
# abort if entry is not loaded
|
||||
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||
return self.async_abort(reason="entry_not_loaded")
|
||||
|
||||
options = self.options
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
step_schema: VolDictType = {}
|
||||
|
||||
if self._is_new:
|
||||
step_schema[vol.Required(CONF_NAME, default=DEFAULT_TTS_NAME)] = str
|
||||
|
||||
step_schema.update(
|
||||
{
|
||||
vol.Optional(CONF_PROMPT): TextSelector(
|
||||
TextSelectorConfig(multiline=True, type=TextSelectorType.TEXT)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_TTS_SPEED, default=RECOMMENDED_TTS_SPEED
|
||||
): NumberSelector(NumberSelectorConfig(min=0.25, max=4.0, step=0.01)),
|
||||
}
|
||||
)
|
||||
|
||||
if user_input is not None:
|
||||
options.update(user_input)
|
||||
if not errors:
|
||||
if self._is_new:
|
||||
return self.async_create_entry(
|
||||
title=options.pop(CONF_NAME),
|
||||
data=options,
|
||||
)
|
||||
return self.async_update_and_abort(
|
||||
self._get_entry(),
|
||||
self._get_reconfigure_subentry(),
|
||||
data=options,
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=self.add_suggested_values_to_schema(
|
||||
vol.Schema(step_schema), options
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ LOGGER: logging.Logger = logging.getLogger(__package__)
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
|
||||
DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
|
||||
DEFAULT_TTS_NAME = "OpenAI TTS"
|
||||
DEFAULT_NAME = "OpenAI Conversation"
|
||||
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
@@ -23,6 +24,7 @@ CONF_REASONING_SUMMARY = "reasoning_summary"
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
CONF_TEMPERATURE = "temperature"
|
||||
CONF_TOP_P = "top_p"
|
||||
CONF_TTS_SPEED = "tts_speed"
|
||||
CONF_VERBOSITY = "verbosity"
|
||||
CONF_WEB_SEARCH = "web_search"
|
||||
CONF_WEB_SEARCH_USER_LOCATION = "user_location"
|
||||
@@ -40,6 +42,7 @@ RECOMMENDED_REASONING_EFFORT = "low"
|
||||
RECOMMENDED_REASONING_SUMMARY = "auto"
|
||||
RECOMMENDED_TEMPERATURE = 1.0
|
||||
RECOMMENDED_TOP_P = 1.0
|
||||
RECOMMENDED_TTS_SPEED = 1.0
|
||||
RECOMMENDED_VERBOSITY = "medium"
|
||||
RECOMMENDED_WEB_SEARCH = False
|
||||
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
|
||||
@@ -105,3 +108,7 @@ RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
RECOMMENDED_AI_TASK_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
RECOMMENDED_TTS_OPTIONS = {
|
||||
CONF_PROMPT: "",
|
||||
CONF_CHAT_MODEL: "gpt-4o-mini-tts",
|
||||
}
|
||||
|
||||
@@ -460,7 +460,7 @@ class OpenAIBaseLLMEntity(Entity):
|
||||
"""OpenAI conversation agent."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
_attr_name = None
|
||||
_attr_name: str | None = None
|
||||
|
||||
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
||||
"""Initialize the entity."""
|
||||
|
||||
@@ -145,6 +145,30 @@
|
||||
"title": "Model-specific options"
|
||||
}
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"abort": {
|
||||
"entry_not_loaded": "[%key:component::openai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||
},
|
||||
"entry_type": "Text-to-speech",
|
||||
"initiate_flow": {
|
||||
"reconfigure": "Reconfigure text-to-speech service",
|
||||
"user": "Add text-to-speech service"
|
||||
},
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"prompt": "[%key:common::config_flow::data::prompt%]",
|
||||
"tts_speed": "Speed"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "Control aspects of speech, including accent, emotional range, intonation, impressions, speed of speech, tone, whispering, and more",
|
||||
"tts_speed": "Speed of the generated speech"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
|
||||
194
homeassistant/components/openai_conversation/tts.py
Normal file
194
homeassistant/components/openai_conversation/tts.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Text to speech support for OpenAI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai import OpenAIError
|
||||
from propcache.api import cached_property
|
||||
|
||||
from homeassistant.components.tts import (
|
||||
ATTR_PREFERRED_FORMAT,
|
||||
ATTR_VOICE,
|
||||
TextToSpeechEntity,
|
||||
TtsAudioType,
|
||||
Voice,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import CONF_CHAT_MODEL, CONF_PROMPT, CONF_TTS_SPEED, RECOMMENDED_TTS_SPEED
|
||||
from .entity import OpenAIBaseLLMEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import OpenAIConfigEntry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: OpenAIConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up TTS entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "tts":
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[OpenAITTSEntity(config_entry, subentry)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
|
||||
class OpenAITTSEntity(TextToSpeechEntity, OpenAIBaseLLMEntity):
|
||||
"""OpenAI TTS entity."""
|
||||
|
||||
_attr_supported_options = [ATTR_VOICE, ATTR_PREFERRED_FORMAT]
|
||||
# https://platform.openai.com/docs/guides/text-to-speech#supported-languages
|
||||
# The model may also generate the audio in different languages but with lower quality
|
||||
_attr_supported_languages = [
|
||||
"af-ZA", # Afrikaans
|
||||
"ar-SA", # Arabic
|
||||
"hy-AM", # Armenian
|
||||
"az-AZ", # Azerbaijani
|
||||
"be-BY", # Belarusian
|
||||
"bs-BA", # Bosnian
|
||||
"bg-BG", # Bulgarian
|
||||
"ca-ES", # Catalan
|
||||
"zh-CN", # Chinese (Mandarin)
|
||||
"hr-HR", # Croatian
|
||||
"cs-CZ", # Czech
|
||||
"da-DK", # Danish
|
||||
"nl-NL", # Dutch
|
||||
"en-US", # English
|
||||
"et-EE", # Estonian
|
||||
"fi-FI", # Finnish
|
||||
"fr-FR", # French
|
||||
"gl-ES", # Galician
|
||||
"de-DE", # German
|
||||
"el-GR", # Greek
|
||||
"he-IL", # Hebrew
|
||||
"hi-IN", # Hindi
|
||||
"hu-HU", # Hungarian
|
||||
"is-IS", # Icelandic
|
||||
"id-ID", # Indonesian
|
||||
"it-IT", # Italian
|
||||
"ja-JP", # Japanese
|
||||
"kn-IN", # Kannada
|
||||
"kk-KZ", # Kazakh
|
||||
"ko-KR", # Korean
|
||||
"lv-LV", # Latvian
|
||||
"lt-LT", # Lithuanian
|
||||
"mk-MK", # Macedonian
|
||||
"ms-MY", # Malay
|
||||
"mr-IN", # Marathi
|
||||
"mi-NZ", # Maori
|
||||
"ne-NP", # Nepali
|
||||
"no-NO", # Norwegian
|
||||
"fa-IR", # Persian
|
||||
"pl-PL", # Polish
|
||||
"pt-PT", # Portuguese
|
||||
"ro-RO", # Romanian
|
||||
"ru-RU", # Russian
|
||||
"sr-RS", # Serbian
|
||||
"sk-SK", # Slovak
|
||||
"sl-SI", # Slovenian
|
||||
"es-ES", # Spanish
|
||||
"sw-KE", # Swahili
|
||||
"sv-SE", # Swedish
|
||||
"fil-PH", # Tagalog (Filipino)
|
||||
"ta-IN", # Tamil
|
||||
"th-TH", # Thai
|
||||
"tr-TR", # Turkish
|
||||
"uk-UA", # Ukrainian
|
||||
"ur-PK", # Urdu
|
||||
"vi-VN", # Vietnamese
|
||||
"cy-GB", # Welsh
|
||||
]
|
||||
# Unused, but required by base class.
|
||||
# The models detect the input language automatically.
|
||||
_attr_default_language = "en-US"
|
||||
|
||||
# https://platform.openai.com/docs/guides/text-to-speech#voice-options
|
||||
_supported_voices = [
|
||||
Voice(voice.lower(), voice)
|
||||
for voice in (
|
||||
"Marin",
|
||||
"Cedar",
|
||||
"Alloy",
|
||||
"Ash",
|
||||
"Ballad",
|
||||
"Coral",
|
||||
"Echo",
|
||||
"Fable",
|
||||
"Nova",
|
||||
"Onyx",
|
||||
"Sage",
|
||||
"Shimmer",
|
||||
"Verse",
|
||||
)
|
||||
]
|
||||
|
||||
_supported_formats = ["mp3", "opus", "aac", "flac", "wav", "pcm"]
|
||||
|
||||
_attr_has_entity_name = False
|
||||
|
||||
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
||||
"""Initialize the entity."""
|
||||
super().__init__(entry, subentry)
|
||||
self._attr_name = subentry.title
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice]:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._supported_voices
|
||||
|
||||
@cached_property
|
||||
def default_options(self) -> Mapping[str, Any]:
|
||||
"""Return a mapping with the default options."""
|
||||
return {
|
||||
ATTR_VOICE: self._supported_voices[0].voice_id,
|
||||
ATTR_PREFERRED_FORMAT: "mp3",
|
||||
}
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine."""
|
||||
|
||||
options = {**self.subentry.data, **options}
|
||||
client = self.entry.runtime_data
|
||||
|
||||
response_format = options[ATTR_PREFERRED_FORMAT]
|
||||
if response_format not in self._supported_formats:
|
||||
# common aliases
|
||||
if response_format == "ogg":
|
||||
response_format = "opus"
|
||||
elif response_format == "raw":
|
||||
response_format = "pcm"
|
||||
else:
|
||||
response_format = self.default_options[ATTR_PREFERRED_FORMAT]
|
||||
|
||||
try:
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model=options[CONF_CHAT_MODEL],
|
||||
voice=options[ATTR_VOICE],
|
||||
input=message,
|
||||
instructions=str(options.get(CONF_PROMPT)),
|
||||
speed=options.get(CONF_TTS_SPEED, RECOMMENDED_TTS_SPEED),
|
||||
response_format=response_format,
|
||||
) as response:
|
||||
response_data = bytearray()
|
||||
async for chunk in response.iter_bytes():
|
||||
response_data.extend(chunk)
|
||||
except OpenAIError as exc:
|
||||
_LOGGER.exception("Error during TTS")
|
||||
raise HomeAssistantError(exc) from exc
|
||||
|
||||
return response_format, bytes(response_data)
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from openai.types import ResponseFormatText
|
||||
from openai.types.responses import (
|
||||
@@ -24,7 +24,9 @@ from homeassistant.components.openai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigSubentryData
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
@@ -53,7 +55,7 @@ def mock_config_entry(
|
||||
"api_key": "bla",
|
||||
},
|
||||
version=2,
|
||||
minor_version=3,
|
||||
minor_version=5,
|
||||
subentries_data=[
|
||||
ConfigSubentryData(
|
||||
data=mock_conversation_subentry_data,
|
||||
@@ -67,6 +69,12 @@ def mock_config_entry(
|
||||
title=DEFAULT_AI_TASK_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
ConfigSubentryData(
|
||||
data=RECOMMENDED_TTS_OPTIONS,
|
||||
subentry_type="tts",
|
||||
title=DEFAULT_TTS_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
@@ -209,3 +217,36 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
||||
)
|
||||
|
||||
yield mock_create
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_create_speech() -> Generator[MagicMock]:
|
||||
"""Mock stream response."""
|
||||
|
||||
class AsyncIterBytesHelper:
|
||||
def __init__(self, chunks) -> None:
|
||||
self.chunks = chunks
|
||||
self.index = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.index >= len(self.chunks):
|
||||
raise StopAsyncIteration
|
||||
chunk = self.chunks[self.index]
|
||||
self.index += 1
|
||||
return chunk
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__.return_value = mock_response
|
||||
mock_create = MagicMock(side_effect=lambda **kwargs: mock_cm)
|
||||
with patch(
|
||||
"openai.resources.audio.speech.async_to_custom_streamed_response_wrapper",
|
||||
return_value=mock_create,
|
||||
):
|
||||
mock_response.iter_bytes.side_effect = lambda **kwargs: AsyncIterBytesHelper(
|
||||
mock_create.return_value
|
||||
)
|
||||
yield mock_create
|
||||
|
||||
@@ -22,6 +22,7 @@ from homeassistant.components.openai_conversation.const import (
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
CONF_TTS_SPEED,
|
||||
CONF_VERBOSITY,
|
||||
CONF_WEB_SEARCH,
|
||||
CONF_WEB_SEARCH_CITY,
|
||||
@@ -33,12 +34,14 @@ from homeassistant.components.openai_conversation.const import (
|
||||
CONF_WEB_SEARCH_USER_LOCATION,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_REASONING_SUMMARY,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -97,6 +100,12 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"subentry_type": "tts",
|
||||
"data": RECOMMENDED_TTS_OPTIONS,
|
||||
"title": DEFAULT_TTS_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
]
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@@ -945,8 +954,8 @@ async def test_creating_ai_task_subentry(
|
||||
) -> None:
|
||||
"""Test creating an AI task subentry."""
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
# Original conversation + original ai_task
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
# Original conversation + original ai_task + original tts
|
||||
assert len(mock_config_entry.subentries) == 3
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
@@ -973,8 +982,8 @@ async def test_creating_ai_task_subentry(
|
||||
}
|
||||
|
||||
assert (
|
||||
len(mock_config_entry.subentries) == 3
|
||||
) # Original conversation + original ai_task + new ai_task
|
||||
len(mock_config_entry.subentries) == 4
|
||||
) # Original conversation + original tts + original ai_task + new ai_task
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
@@ -1058,6 +1067,91 @@ async def test_creating_ai_task_subentry_advanced(
|
||||
}
|
||||
|
||||
|
||||
async def test_creating_tts_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test creating a TTS subentry."""
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
# Original conversation + original ai_task + original tts
|
||||
assert len(mock_config_entry.subentries) == 3
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "tts"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.FORM
|
||||
assert result.get("step_id") == "init"
|
||||
assert not result.get("errors")
|
||||
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"name": "Custom TTS",
|
||||
CONF_PROMPT: "Speak like a drunk pirate",
|
||||
CONF_TTS_SPEED: 0.85,
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.CREATE_ENTRY
|
||||
assert result.get("title") == "Custom TTS"
|
||||
assert result.get("data") == {
|
||||
CONF_PROMPT: "Speak like a drunk pirate",
|
||||
CONF_TTS_SPEED: 0.85,
|
||||
CONF_CHAT_MODEL: "gpt-4o-mini-tts",
|
||||
}
|
||||
|
||||
assert (
|
||||
len(mock_config_entry.subentries) == 4
|
||||
) # Original conversation + original ai_task + original tts + new tts
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
assert new_subentry.subentry_type == "tts"
|
||||
assert new_subentry.title == "Custom TTS"
|
||||
|
||||
|
||||
async def test_tts_subentry_not_loaded(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating a TTS subentry when entry is not loaded."""
|
||||
# Don't call mock_init_component to simulate not loaded state
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "tts"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.ABORT
|
||||
assert result.get("reason") == "entry_not_loaded"
|
||||
|
||||
|
||||
async def test_tts_reconfigure(
|
||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||
) -> None:
|
||||
"""Test the tts subentry reconfigure flow with."""
|
||||
subentry = [
|
||||
s for s in mock_config_entry.subentries.values() if s.subentry_type == "tts"
|
||||
][0]
|
||||
subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||
hass, subentry.subentry_id
|
||||
)
|
||||
options = await hass.config_entries.subentries.async_configure(
|
||||
subentry_flow["flow_id"],
|
||||
{
|
||||
"prompt": "Speak like a pirate",
|
||||
"tts_speed": 0.5,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert options["type"] is FlowResultType.ABORT
|
||||
assert options["reason"] == "reconfigure_successful"
|
||||
assert subentry.data["prompt"] == "Speak like a pirate"
|
||||
assert subentry.data["tts_speed"] == 0.5
|
||||
|
||||
|
||||
async def test_reauth(hass: HomeAssistant) -> None:
|
||||
"""Test we can reauthenticate."""
|
||||
# Pretend we already set up a config entry.
|
||||
|
||||
@@ -21,9 +21,11 @@ from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
|
||||
from homeassistant.components.openai_conversation.const import (
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_TTS_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_TTS_OPTIONS,
|
||||
)
|
||||
from homeassistant.config_entries import (
|
||||
SOURCE_REAUTH,
|
||||
@@ -661,20 +663,23 @@ async def test_migration_from_v1(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_config_entry.version == 2
|
||||
assert mock_config_entry.minor_version == 4
|
||||
assert mock_config_entry.minor_version == 5
|
||||
assert mock_config_entry.data == {"api_key": "1234"}
|
||||
assert mock_config_entry.options == {}
|
||||
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
assert len(mock_config_entry.subentries) == 3
|
||||
|
||||
# Find the conversation subentry
|
||||
conversation_subentry = None
|
||||
ai_task_subentry = None
|
||||
tts_subentry = None
|
||||
for subentry in mock_config_entry.subentries.values():
|
||||
if subentry.subentry_type == "conversation":
|
||||
conversation_subentry = subentry
|
||||
elif subentry.subentry_type == "ai_task_data":
|
||||
ai_task_subentry = subentry
|
||||
elif subentry.subentry_type == "tts":
|
||||
tts_subentry = subentry
|
||||
assert conversation_subentry is not None
|
||||
assert conversation_subentry.unique_id is None
|
||||
assert conversation_subentry.title == "ChatGPT"
|
||||
@@ -686,6 +691,11 @@ async def test_migration_from_v1(
|
||||
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
||||
assert ai_task_subentry.subentry_type == "ai_task_data"
|
||||
|
||||
assert tts_subentry is not None
|
||||
assert tts_subentry.unique_id is None
|
||||
assert tts_subentry.title == DEFAULT_TTS_NAME
|
||||
assert tts_subentry.subentry_type == "tts"
|
||||
|
||||
# Use conversation subentry for the rest of the assertions
|
||||
subentry = conversation_subentry
|
||||
|
||||
@@ -790,9 +800,9 @@ async def test_migration_from_v1_with_multiple_keys(
|
||||
|
||||
for idx, entry in enumerate(entries):
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 4
|
||||
assert entry.minor_version == 5
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 2
|
||||
assert len(entry.subentries) == 3
|
||||
|
||||
conversation_subentry = None
|
||||
for subentry in entry.subentries.values():
|
||||
@@ -895,11 +905,11 @@ async def test_migration_from_v1_with_same_keys(
|
||||
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 4
|
||||
assert entry.minor_version == 5
|
||||
assert not entry.options
|
||||
assert (
|
||||
len(entry.subentries) == 3
|
||||
) # Two conversation subentries + one AI task subentry
|
||||
len(entry.subentries) == 4
|
||||
) # Two conversation subentries + one AI task subentry + one TTS subentry
|
||||
|
||||
# Check both conversation subentries exist with correct data
|
||||
conversation_subentries = [
|
||||
@@ -908,9 +918,13 @@ async def test_migration_from_v1_with_same_keys(
|
||||
ai_task_subentries = [
|
||||
sub for sub in entry.subentries.values() if sub.subentry_type == "ai_task_data"
|
||||
]
|
||||
tts_subentries = [
|
||||
sub for sub in entry.subentries.values() if sub.subentry_type == "tts"
|
||||
]
|
||||
|
||||
assert len(conversation_subentries) == 2
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert len(tts_subentries) == 1
|
||||
|
||||
titles = [sub.title for sub in conversation_subentries]
|
||||
assert "ChatGPT" in titles
|
||||
@@ -1098,10 +1112,12 @@ async def test_migration_from_v1_disabled(
|
||||
entry = entries[0]
|
||||
assert entry.disabled_by is merged_config_entry_disabled_by
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 4
|
||||
assert entry.minor_version == (
|
||||
4 if merged_config_entry_disabled_by is not None else 5
|
||||
)
|
||||
assert not entry.options
|
||||
assert entry.title == "OpenAI Conversation"
|
||||
assert len(entry.subentries) == 3
|
||||
assert len(entry.subentries) == (3 if entry.minor_version == 4 else 4)
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@@ -1120,6 +1136,17 @@ async def test_migration_from_v1_disabled(
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert ai_task_subentries[0].data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert ai_task_subentries[0].title == DEFAULT_AI_TASK_NAME
|
||||
tts_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "tts"
|
||||
]
|
||||
if entry.minor_version == 4:
|
||||
assert len(tts_subentries) == 0
|
||||
else:
|
||||
assert len(tts_subentries) == 1
|
||||
assert tts_subentries[0].data == RECOMMENDED_TTS_OPTIONS
|
||||
assert tts_subentries[0].title == DEFAULT_TTS_NAME
|
||||
|
||||
assert not device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, mock_config_entry.entry_id)}
|
||||
@@ -1250,10 +1277,10 @@ async def test_migration_from_v2_1(
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 4
|
||||
assert entry.minor_version == 5
|
||||
assert not entry.options
|
||||
assert entry.title == "ChatGPT"
|
||||
assert len(entry.subentries) == 3 # 2 conversation + 1 AI task
|
||||
assert len(entry.subentries) == 4 # 2 conversation + 1 AI task + 1 TTS
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@@ -1264,8 +1291,14 @@ async def test_migration_from_v2_1(
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task_data"
|
||||
]
|
||||
tts_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "tts"
|
||||
]
|
||||
assert len(conversation_subentries) == 2
|
||||
assert len(ai_task_subentries) == 1
|
||||
assert len(tts_subentries) == 1
|
||||
for subentry in conversation_subentries:
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == options
|
||||
@@ -1329,7 +1362,7 @@ async def test_devices(
|
||||
devices = dr.async_entries_for_config_entry(
|
||||
device_registry, mock_config_entry.entry_id
|
||||
)
|
||||
assert len(devices) == 2 # One for conversation, one for AI task
|
||||
assert len(devices) == 3 # One for conversation, one for AI task, one for TTS
|
||||
|
||||
# Use the first device for snapshot comparison
|
||||
device = devices[0]
|
||||
@@ -1386,10 +1419,10 @@ async def test_migration_from_v2_2(
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 4
|
||||
assert entry.minor_version == 5
|
||||
assert not entry.options
|
||||
assert entry.title == "ChatGPT"
|
||||
assert len(entry.subentries) == 2
|
||||
assert len(entry.subentries) == 3
|
||||
|
||||
# Check conversation subentry is still there
|
||||
conversation_subentries = [
|
||||
@@ -1431,7 +1464,7 @@ async def test_migration_from_v2_2(
|
||||
DeviceEntryDisabler.CONFIG_ENTRY,
|
||||
RegistryEntryDisabler.CONFIG_ENTRY,
|
||||
True,
|
||||
4,
|
||||
5,
|
||||
None,
|
||||
DeviceEntryDisabler.USER,
|
||||
RegistryEntryDisabler.DEVICE,
|
||||
@@ -1441,7 +1474,7 @@ async def test_migration_from_v2_2(
|
||||
DeviceEntryDisabler.USER,
|
||||
RegistryEntryDisabler.DEVICE,
|
||||
True,
|
||||
4,
|
||||
5,
|
||||
None,
|
||||
DeviceEntryDisabler.USER,
|
||||
RegistryEntryDisabler.DEVICE,
|
||||
@@ -1451,7 +1484,7 @@ async def test_migration_from_v2_2(
|
||||
DeviceEntryDisabler.USER,
|
||||
RegistryEntryDisabler.USER,
|
||||
True,
|
||||
4,
|
||||
5,
|
||||
None,
|
||||
DeviceEntryDisabler.USER,
|
||||
RegistryEntryDisabler.USER,
|
||||
@@ -1461,7 +1494,7 @@ async def test_migration_from_v2_2(
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
4,
|
||||
5,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -1596,3 +1629,95 @@ async def test_migrate_entry_from_v2_3(
|
||||
assert mock_config_entry.disabled_by == config_entry_disabled_by_after_migration
|
||||
assert conversation_device.disabled_by == device_disabled_by_after_migration
|
||||
assert conversation_entity.disabled_by == entity_disabled_by_after_migration
|
||||
|
||||
|
||||
async def test_migration_from_v2_4(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test migration from version 2.4."""
|
||||
# Create a v2.4 config entry with a conversation and AI Task subentries
|
||||
conversation_options = {
|
||||
"recommended": True,
|
||||
"llm_hass_api": ["assist"],
|
||||
"prompt": "You are a helpful assistant",
|
||||
"chat_model": "gpt-4o-mini",
|
||||
}
|
||||
ai_task_options = {
|
||||
"recommended": True,
|
||||
"chat_model": "gpt-5-mini",
|
||||
}
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={"api_key": "1234"},
|
||||
entry_id="mock_entry_id",
|
||||
version=2,
|
||||
minor_version=4,
|
||||
subentries_data=[
|
||||
ConfigSubentryData(
|
||||
data=conversation_options,
|
||||
subentry_id="mock_id_1",
|
||||
subentry_type="conversation",
|
||||
title="ChatGPT",
|
||||
unique_id=None,
|
||||
),
|
||||
ConfigSubentryData(
|
||||
data=ai_task_options,
|
||||
subentry_id="mock_id_2",
|
||||
subentry_type="ai_task_data",
|
||||
title="OpenAI AI Task",
|
||||
unique_id=None,
|
||||
),
|
||||
],
|
||||
title="ChatGPT",
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
# Run migration
|
||||
with patch(
|
||||
"homeassistant.components.openai_conversation.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 5
|
||||
assert not entry.options
|
||||
assert entry.title == "ChatGPT"
|
||||
assert len(entry.subentries) == 3
|
||||
|
||||
# Check conversation subentry is still there
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "conversation"
|
||||
]
|
||||
assert len(conversation_subentries) == 1
|
||||
conversation_subentry = conversation_subentries[0]
|
||||
assert conversation_subentry.data == conversation_options
|
||||
|
||||
# Check AI Task subentry is still there
|
||||
ai_task_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task_data"
|
||||
]
|
||||
assert len(ai_task_subentries) == 1
|
||||
ai_task_subentry = ai_task_subentries[0]
|
||||
assert ai_task_subentry.data == ai_task_options
|
||||
|
||||
# Check TTS subentry was added
|
||||
tts_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "tts"
|
||||
]
|
||||
assert len(tts_subentries) == 1
|
||||
tts_subentry = tts_subentries[0]
|
||||
assert tts_subentry.data == {"chat_model": "gpt-4o-mini-tts", "prompt": ""}
|
||||
assert tts_subentry.title == "OpenAI TTS"
|
||||
|
||||
164
tests/components/openai_conversation/test_tts.py
Normal file
164
tests/components/openai_conversation/test_tts.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Test TTS platform of OpenAI Conversation integration."""
|
||||
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
from openai import RateLimitError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.core_config import async_process_ha_core_config
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
from tests.common import MockConfigEntry, async_mock_service
|
||||
from tests.components.tts.common import retrieve_media
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
|
||||
"""Mock writing tags."""
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
|
||||
"""Mock the TTS cache dir with empty dir."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
|
||||
"""Mock media player calls."""
|
||||
return async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_internal_url(hass: HomeAssistant) -> None:
|
||||
"""Set up internal url."""
|
||||
await async_process_ha_core_config(
|
||||
hass, {"internal_url": "http://example.local:8123"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"service_data",
|
||||
[
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.openai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {},
|
||||
},
|
||||
{
|
||||
ATTR_ENTITY_ID: "tts.openai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_tts(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_create_speech: MagicMock,
|
||||
entity_registry: er.EntityRegistry,
|
||||
calls: list[ServiceCall],
|
||||
service_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test text to speech generation."""
|
||||
entity_id = "tts.openai_tts"
|
||||
|
||||
# Ensure entity is linked to the subentry
|
||||
entity_entry = entity_registry.async_get(entity_id)
|
||||
tts_entry = next(
|
||||
iter(
|
||||
entry
|
||||
for entry in mock_config_entry.subentries.values()
|
||||
if entry.subentry_type == "tts"
|
||||
)
|
||||
)
|
||||
assert entity_entry is not None
|
||||
assert entity_entry.config_entry_id == mock_config_entry.entry_id
|
||||
assert entity_entry.config_subentry_id == tts_entry.subentry_id
|
||||
|
||||
# Mock the OpenAI response stream
|
||||
mock_create_speech.return_value = [b"mock aud", b"io data"]
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
"speak",
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.OK
|
||||
)
|
||||
voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "marin")
|
||||
mock_create_speech.assert_called_once_with(
|
||||
model="gpt-4o-mini-tts",
|
||||
voice=voice_id,
|
||||
input="There is a person at the front door.",
|
||||
instructions="",
|
||||
speed=1.0,
|
||||
response_format="mp3",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_tts_error(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_create_speech: MagicMock,
|
||||
entity_registry: er.EntityRegistry,
|
||||
calls: list[ServiceCall],
|
||||
) -> None:
|
||||
"""Test exception handling during text to speech generation."""
|
||||
# Mock the OpenAI response stream
|
||||
mock_create_speech.side_effect = RateLimitError(
|
||||
response=httpx.Response(status_code=429, request=""),
|
||||
body=None,
|
||||
message=None,
|
||||
)
|
||||
|
||||
service_data = {
|
||||
ATTR_ENTITY_ID: "tts.openai_tts",
|
||||
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
|
||||
tts.ATTR_MESSAGE: "There is a person at the front door.",
|
||||
tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
|
||||
}
|
||||
|
||||
await hass.services.async_call(
|
||||
tts.DOMAIN,
|
||||
"speak",
|
||||
service_data,
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert (
|
||||
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
== HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
mock_create_speech.assert_called_once_with(
|
||||
model="gpt-4o-mini-tts",
|
||||
voice="voice1",
|
||||
input="There is a person at the front door.",
|
||||
instructions="",
|
||||
speed=1.0,
|
||||
response_format="mp3",
|
||||
)
|
||||
Reference in New Issue
Block a user