1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-22 02:47:14 +00:00
Files
core/tests/components/fish_audio/test_tts.py
2026-02-17 14:37:17 +01:00

317 lines
9.5 KiB
Python

"""Tests for the Fish Audio TTS entity."""
from __future__ import annotations
from http import HTTPStatus
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from fishaudio import RateLimitError
from fishaudio.exceptions import ServerError
import pytest
from homeassistant.components import tts
from homeassistant.components.fish_audio.const import CONF_BACKEND
from homeassistant.components.media_player import (
ATTR_MEDIA_CONTENT_ID,
DOMAIN as MP_DOMAIN,
SERVICE_PLAY_MEDIA,
)
from homeassistant.config_entries import ConfigSubentryData
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core_config import async_process_ha_core_config
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from tests.common import MockConfigEntry, async_mock_service
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
"""Mock writing tags."""
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir."""
@pytest.fixture(autouse=True)
async def setup_internal_url(hass: HomeAssistant) -> None:
"""Set up internal url."""
await async_process_ha_core_config(
hass, {"internal_url": "http://example.local:8123"}
)
@pytest.fixture
async def calls(hass: HomeAssistant) -> list[ServiceCall]:
"""Mock media player calls."""
return async_mock_service(hass, MP_DOMAIN, SERVICE_PLAY_MEDIA)
async def test_tts_service_success(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test TTS service with successful audio generation."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Get the TTS entity
entity = hass.data[tts.DOMAIN].get_entity("tts.test_voice_test_voice")
assert entity is not None
# Test the TTS audio generation
extension, data = await entity.async_get_tts_audio(
message="Hello world",
language="en",
options={},
)
# Verify the result
assert extension == "mp3"
assert data == b"fake_audio_data"
# Verify the client was called
mock_fishaudio_client.tts.convert.assert_called_once()
async def test_tts_rate_limited(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test TTS service with rate limit error."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Get the TTS entity
entity = hass.data[tts.DOMAIN].get_entity("tts.test_voice_test_voice")
assert entity is not None
# Make tts.convert() fail with rate limit error
mock_fishaudio_client.tts.convert = AsyncMock(
side_effect=RateLimitError(429, "Rate limited")
)
# Test that the error is raised
with pytest.raises(HomeAssistantError, match="Rate limited"):
await entity.async_get_tts_audio(
message="Hello world",
language="en",
options={},
)
# Verify the client was called
mock_fishaudio_client.tts.convert.assert_called_once()
async def test_tts_missing_voice_id(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test TTS service raises ServiceValidationError when voice_id is missing."""
# Create a config entry with no voice_id
entry = MockConfigEntry(
domain="fish_audio",
data={"api_key": "test-key"},
unique_id="test_user",
subentries_data=[
ConfigSubentryData(
data={CONF_BACKEND: "s1"}, # Missing CONF_VOICE_ID
subentry_type="tts",
title="Test Voice",
subentry_id="test-sub-id",
unique_id="test-voice",
)
],
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
# Get the TTS entity
entity = hass.data[tts.DOMAIN].get_entity("tts.test_voice_test_voice")
assert entity is not None
# Test that the error is raised
with pytest.raises(ServiceValidationError, match="Voice ID not configured"):
await entity.async_get_tts_audio(
message="Hello world",
language="en",
options={},
)
async def test_tts_supported_languages(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test TTS entity supported languages."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Get the TTS entity
entity = hass.data[tts.DOMAIN].get_entity("tts.test_voice_test_voice")
assert entity is not None
# Verify supported languages
assert entity.supported_languages == [
"Any",
"en",
"zh",
"de",
"ja",
"ar",
"fr",
"es",
"ko",
]
# Service-level integration tests
async def test_tts_service_speak(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
) -> None:
"""Test TTS speak service call."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
await hass.services.async_call(
tts.DOMAIN,
"speak",
{
ATTR_ENTITY_ID: "tts.test_voice_test_voice",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "Hello world",
},
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
async def test_tts_service_speak_with_language(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
) -> None:
"""Test TTS speak service call with language parameter."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
await hass.services.async_call(
tts.DOMAIN,
"speak",
{
ATTR_ENTITY_ID: "tts.test_voice_test_voice",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "Hola mundo",
tts.ATTR_LANGUAGE: "es",
},
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.OK
)
async def test_tts_service_speak_server_error(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
) -> None:
"""Test TTS speak service call with server error."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Get the client from runtime_data and make it fail with ServerError
entry = hass.config_entries.async_get_entry(mock_config_entry.entry_id)
assert entry is not None
mock_fishaudio_client.tts.convert = AsyncMock(
side_effect=ServerError(500, "Internal server error")
)
await hass.services.async_call(
tts.DOMAIN,
"speak",
{
ATTR_ENTITY_ID: "tts.test_voice_test_voice",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "Test server error",
},
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.INTERNAL_SERVER_ERROR
)
async def test_tts_service_speak_rate_limit_error(
hass: HomeAssistant,
mock_fishaudio_client: AsyncMock,
mock_config_entry: MockConfigEntry,
hass_client: ClientSessionGenerator,
calls: list[ServiceCall],
) -> None:
"""Test TTS speak service call with rate limit error."""
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Get the client from runtime_data and make it fail with RateLimitError
entry = hass.config_entries.async_get_entry(mock_config_entry.entry_id)
assert entry is not None
mock_fishaudio_client.tts.convert = AsyncMock(
side_effect=RateLimitError(429, "Rate limit exceeded")
)
await hass.services.async_call(
tts.DOMAIN,
"speak",
{
ATTR_ENTITY_ID: "tts.test_voice_test_voice",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "Test rate limit error",
},
blocking=True,
)
assert len(calls) == 1
assert (
await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== HTTPStatus.INTERNAL_SERVER_ERROR
)