mirror of
https://github.com/home-assistant/core.git
synced 2025-12-22 03:49:36 +00:00
344 lines
11 KiB
Python
344 lines
11 KiB
Python
"""Tests for the Home Assistant Cloud AI Task entity."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from hass_nabucasa.llm import (
|
|
LLMAuthenticationError,
|
|
LLMError,
|
|
LLMImageAttachment,
|
|
LLMRateLimitError,
|
|
LLMResponseError,
|
|
LLMServiceError,
|
|
)
|
|
from PIL import Image
|
|
import pytest
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import ai_task, conversation
|
|
from homeassistant.components.cloud.ai_task import (
|
|
CloudAITaskEntity,
|
|
async_prepare_image_generation_attachments,
|
|
)
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
@pytest.fixture
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudAITaskEntity:
    """Build a CloudAITaskEntity wired to a fake cloud client.

    The fake client's ``llm`` exposes awaitable ``async_generate_image`` and
    ``async_edit_image`` mocks, and the client reports a logged-in user with a
    valid subscription. A ``cloud`` config entry is registered with ``hass``.
    """
    fake_llm = MagicMock(
        async_generate_image=AsyncMock(),
        async_edit_image=AsyncMock(),
    )
    fake_cloud = MagicMock()
    fake_cloud.llm = fake_llm
    fake_cloud.is_logged_in = True
    fake_cloud.valid_subscription = True

    config_entry = MockConfigEntry(domain="cloud")
    config_entry.add_to_hass(hass)

    ai_task_entity = CloudAITaskEntity(fake_cloud, config_entry)
    # The entity is not added through a platform, so set the attributes a
    # platform would normally assign.
    ai_task_entity.entity_id = "ai_task.cloud_ai_task"
    ai_task_entity.hass = hass
    return ai_task_entity
|
|
|
|
|
|
@pytest.fixture(name="mock_handle_chat_log")
def mock_handle_chat_log_fixture() -> Generator[AsyncMock]:
    """Patch ``CloudAITaskEntity._async_handle_chat_log`` with an AsyncMock.

    This is a generator fixture: the patch is active while the yielded mock is
    in use and is undone on teardown. Tests assign ``side_effect`` on the mock
    to inject assistant content into the chat log.
    """
    with patch(
        "homeassistant.components.cloud.ai_task.CloudAITaskEntity._async_handle_chat_log",
        AsyncMock(),
    ) as mock:
        yield mock
|
|
|
|
|
|
@pytest.fixture(name="mock_prepare_generation_attachments")
def mock_prepare_generation_attachments_fixture() -> Generator[AsyncMock]:
    """Patch ``async_prepare_image_generation_attachments`` in the cloud module.

    This is a generator fixture: the patch is active while the yielded mock is
    in use and is undone on teardown. Tests set ``return_value`` to control
    which prepared attachments the entity receives.
    """
    with patch(
        "homeassistant.components.cloud.ai_task.async_prepare_image_generation_attachments",
        AsyncMock(),
    ) as mock:
        yield mock
|
|
|
|
|
|
async def test_prepare_image_generation_attachments(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test that a JPEG attachment comes back prepared as PNG data."""
    source_image = tmp_path / "snapshot.jpg"
    # Write a tiny but real JPEG so preparation has genuine image data to read.
    Image.new("RGB", (2, 2), "red").save(source_image, "JPEG")

    prepared = await async_prepare_image_generation_attachments(
        hass,
        [
            conversation.Attachment(
                media_content_id="media-source://media/snapshot.jpg",
                mime_type="image/jpeg",
                path=source_image,
            )
        ],
    )

    assert len(prepared) == 1
    first = prepared[0]
    assert first["filename"] == "snapshot.jpg"
    assert first["mime_type"] == "image/png"
    # PNG streams start with the signature bytes b"\x89PNG".
    assert first["data"].startswith(b"\x89PNG")
|
|
|
|
|
|
async def test_prepare_image_generation_attachments_only_images(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test that an attachment with a non-image MIME type is refused."""
    text_file = tmp_path / "context.txt"
    text_file.write_text("context")
    text_attachment = conversation.Attachment(
        media_content_id="media-source://media/context.txt",
        mime_type="text/plain",
        path=text_file,
    )

    with pytest.raises(
        HomeAssistantError,
        match="Only image attachments are supported for image generation",
    ):
        await async_prepare_image_generation_attachments(hass, [text_attachment])
|
|
|
|
|
|
async def test_prepare_image_generation_attachments_missing_file(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test that a missing attachment file raises a helpful error.

    The referenced path is deliberately never created on disk.
    """
    missing_path = tmp_path / "missing.png"

    attachments = [
        conversation.Attachment(
            media_content_id="media-source://media/missing.png",
            mime_type="image/png",
            path=missing_path,
        )
    ]

    # ``match`` is a regular expression searched in the error message, so the
    # dot in the file name is escaped to match only a literal ".".
    with pytest.raises(
        HomeAssistantError, match=r"`.*missing\.png` does not exist"
    ):
        await async_prepare_image_generation_attachments(hass, attachments)
|
|
|
|
|
|
async def test_prepare_image_generation_attachments_processing_error(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test that undecodable image bytes raise a processing error."""
    corrupt_file = tmp_path / "broken.png"
    # Declared as a PNG, but the contents are not image data at all.
    corrupt_file.write_bytes(b"not-an-image")
    corrupt_attachment = conversation.Attachment(
        media_content_id="media-source://media/broken.png",
        mime_type="image/png",
        path=corrupt_file,
    )

    with pytest.raises(
        HomeAssistantError,
        match="Failed to process image attachment",
    ):
        await async_prepare_image_generation_attachments(
            hass, [corrupt_attachment]
        )
|
|
|
|
|
|
async def test_generate_data_returns_text(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test that a task without a structure returns the raw assistant text."""
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(
        conversation.UserContent(content="Tell me something")
    )
    task = ai_task.GenDataTask(name="Task", instructions="Say hi")

    async def _inject_reply(chat_type, log, task_name, structure):
        """Stand in for the LLM by appending a canned assistant reply."""
        assert chat_type == "ai_task"
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content="Hello from the cloud",
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = _inject_reply

    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    assert result.conversation_id == "conversation-id"
    assert result.data == "Hello from the cloud"
|
|
|
|
|
|
async def test_generate_data_returns_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test that a structured task parses the assistant reply as JSON."""
    names_schema = vol.Schema({vol.Required("names"): [str]})
    task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=names_schema,
    )
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))

    async def _inject_json_reply(chat_type, log, task_name, structure):
        """Append a well-formed JSON reply to the chat log."""
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content='{"names": ["A", "B"]}',
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = _inject_json_reply

    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    assert result.data == {"names": ["A", "B"]}
|
|
|
|
|
|
async def test_generate_data_invalid_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test that a non-JSON reply to a structured task raises an error."""
    names_schema = vol.Schema({vol.Required("names"): [str]})
    task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=names_schema,
    )
    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))

    async def _inject_garbage_reply(chat_type, log, task_name, structure):
        """Append a reply that cannot be parsed as JSON."""
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content="not-json",
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = _inject_garbage_reply

    with pytest.raises(
        HomeAssistantError, match="Error with OpenAI structured response"
    ):
        await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)
|
|
|
|
|
|
async def test_generate_image_no_attachments(
    hass: HomeAssistant, mock_cloud_ai_task_entity: CloudAITaskEntity
) -> None:
    """Test that an attachment-free task calls the image generation endpoint."""
    llm = mock_cloud_ai_task_entity._cloud.llm
    llm.async_generate_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
        "model": "mock-image",
        "width": 1024,
        "height": 768,
        "revised_prompt": "Improved prompt",
    }
    chat_log = conversation.ChatLog(hass, "conversation-id")
    task = ai_task.GenImageTask(name="Task", instructions="Draw something")

    result = await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

    assert result.image_data == b"IMG"
    assert result.mime_type == "image/png"
    # The task instructions are forwarded verbatim as the prompt.
    llm.async_generate_image.assert_awaited_once_with(prompt="Draw something")
|
|
|
|
|
|
async def test_generate_image_with_attachments(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_prepare_generation_attachments: AsyncMock,
) -> None:
    """Test that attachments route the request to the image edit endpoint.

    Attachment preparation is mocked, so the referenced file does not need to
    exist. The prepared payload must be forwarded to ``async_edit_image``
    unchanged, and the plain generation endpoint must not be used.
    """
    llm = mock_cloud_ai_task_entity._cloud.llm
    llm.async_edit_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
    }
    task = ai_task.GenImageTask(
        name="Task",
        instructions="Edit this",
        attachments=[
            conversation.Attachment(
                media_content_id="media-source://media/snapshot.png",
                mime_type="image/png",
                # hass.config.path() returns a plain str; wrap it in Path so
                # the attachment carries a pathlib.Path like the other tests.
                path=Path(hass.config.path("snapshot.png")),
            )
        ],
    )
    chat_log = conversation.ChatLog(hass, "conversation-id")
    prepared_attachments = [
        LLMImageAttachment(filename="snapshot.png", mime_type="image/png", data=b"IMG")
    ]

    mock_prepare_generation_attachments.return_value = prepared_attachments
    await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

    llm.async_edit_image.assert_awaited_once_with(
        prompt="Edit this",
        attachments=prepared_attachments,
    )
    llm.async_generate_image.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("err", "expected_exception", "message"),
    [
        (
            LLMAuthenticationError("auth"),
            HomeAssistantError,
            "Cloud LLM authentication failed",
        ),
        (
            LLMRateLimitError("limit"),
            HomeAssistantError,
            "Cloud LLM is rate limited",
        ),
        (
            LLMResponseError("bad response"),
            HomeAssistantError,
            "bad response",
        ),
        (
            LLMServiceError("service"),
            HomeAssistantError,
            "Error talking to Cloud LLM",
        ),
        (
            LLMError("generic"),
            HomeAssistantError,
            "generic",
        ),
    ],
)
async def test_generate_image_error_handling(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    err: Exception,
    expected_exception: type[Exception],
    message: str,
) -> None:
    """Test that cloud LLM failures surface as Home Assistant errors.

    Each parametrized case makes the image generation endpoint raise ``err``
    and checks the exception type and message the entity raises in response.
    """
    llm = mock_cloud_ai_task_entity._cloud.llm
    llm.async_generate_image.side_effect = err
    chat_log = conversation.ChatLog(hass, "conversation-id")
    task = ai_task.GenImageTask(name="Task", instructions="Draw something")

    with pytest.raises(expected_exception, match=message):
        await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)
|