1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-22 03:49:36 +00:00
Files
core/tests/components/cloud/test_ai_task.py

344 lines
11 KiB
Python

"""Tests for the Home Assistant Cloud AI Task entity."""
from __future__ import annotations

from collections.abc import Generator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

from hass_nabucasa.llm import (
    LLMAuthenticationError,
    LLMError,
    LLMImageAttachment,
    LLMRateLimitError,
    LLMResponseError,
    LLMServiceError,
)
from PIL import Image
import pytest
import voluptuous as vol

from homeassistant.components import ai_task, conversation
from homeassistant.components.cloud.ai_task import (
    CloudAITaskEntity,
    async_prepare_image_generation_attachments,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError

from tests.common import MockConfigEntry
@pytest.fixture
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudAITaskEntity:
    """Build a CloudAITaskEntity backed by a fake cloud client.

    The fake cloud reports itself as logged in with a valid subscription and
    exposes awaitable mocks for image generation and image editing.
    """
    # Config entry the entity is constructed with.
    config_entry = MockConfigEntry(domain="cloud")
    config_entry.add_to_hass(hass)

    fake_cloud = MagicMock()
    fake_cloud.llm = MagicMock(
        async_generate_image=AsyncMock(),
        async_edit_image=AsyncMock(),
    )
    fake_cloud.is_logged_in = True
    fake_cloud.valid_subscription = True

    ai_entity = CloudAITaskEntity(fake_cloud, config_entry)
    # Set by hand because the entity is never added through a platform here.
    ai_entity.entity_id = "ai_task.cloud_ai_task"
    ai_entity.hass = hass
    return ai_entity
@pytest.fixture(name="mock_handle_chat_log")
def mock_handle_chat_log_fixture() -> Generator[AsyncMock]:
    """Patch the entity's chat log handler for the duration of a test.

    Yields the AsyncMock that replaces
    ``CloudAITaskEntity._async_handle_chat_log`` so tests can set a
    ``side_effect`` on it; the patch is undone when the test finishes.
    """
    with patch(
        "homeassistant.components.cloud.ai_task.CloudAITaskEntity._async_handle_chat_log",
        AsyncMock(),
    ) as mock:
        # Yield (rather than return) so the patch stays active during the test.
        yield mock
@pytest.fixture(name="mock_prepare_generation_attachments")
def mock_prepare_generation_attachments_fixture() -> Generator[AsyncMock]:
    """Patch image generation attachment preparation for the duration of a test.

    Yields the AsyncMock that replaces
    ``async_prepare_image_generation_attachments`` in the cloud ``ai_task``
    module so tests can set its ``return_value``; the patch is undone when the
    test finishes.
    """
    with patch(
        "homeassistant.components.cloud.ai_task.async_prepare_image_generation_attachments",
        AsyncMock(),
    ) as mock:
        # Yield (rather than return) so the patch stays active during the test.
        yield mock
async def test_prepare_image_generation_attachments(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test a JPEG attachment comes back as a PNG-encoded attachment."""
    source_file = tmp_path / "snapshot.jpg"
    # Write a tiny but real JPEG so there is valid image data to process.
    Image.new("RGB", (2, 2), "red").save(source_file, "JPEG")

    jpeg_attachment = conversation.Attachment(
        media_content_id="media-source://media/snapshot.jpg",
        mime_type="image/jpeg",
        path=source_file,
    )

    prepared = await async_prepare_image_generation_attachments(
        hass, [jpeg_attachment]
    )

    assert len(prepared) == 1
    only = prepared[0]
    assert only["filename"] == "snapshot.jpg"
    assert only["mime_type"] == "image/png"
    # PNG files start with this magic signature.
    assert only["data"].startswith(b"\x89PNG")
async def test_prepare_image_generation_attachments_only_images(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test a text/plain attachment is refused with a HomeAssistantError."""
    text_file = tmp_path / "context.txt"
    text_file.write_text("context")

    text_attachment = conversation.Attachment(
        media_content_id="media-source://media/context.txt",
        mime_type="text/plain",
        path=text_file,
    )

    expect_rejection = pytest.raises(
        HomeAssistantError,
        match="Only image attachments are supported for image generation",
    )
    with expect_rejection:
        await async_prepare_image_generation_attachments(hass, [text_attachment])
async def test_prepare_image_generation_attachments_missing_file(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test an attachment whose file does not exist raises a helpful error."""
    # The path is never created, so it cannot be read.
    missing_path = tmp_path / "missing.png"
    attachments = [
        conversation.Attachment(
            media_content_id="media-source://media/missing.png",
            mime_type="image/png",
            path=missing_path,
        )
    ]

    # ``match`` is a regular expression: use a raw string and escape the dot in
    # the file name so it only matches a literal ".", while ``.*`` still
    # absorbs the temporary directory prefix.
    with pytest.raises(
        HomeAssistantError, match=r"`.*missing\.png` does not exist"
    ):
        await async_prepare_image_generation_attachments(hass, attachments)
async def test_prepare_image_generation_attachments_processing_error(
    hass: HomeAssistant, tmp_path: Path
) -> None:
    """Test a file that is not a decodable image raises a processing error."""
    corrupt_file = tmp_path / "broken.png"
    # Declared as image/png below, but the bytes are not image data.
    corrupt_file.write_bytes(b"not-an-image")

    corrupt_attachment = conversation.Attachment(
        media_content_id="media-source://media/broken.png",
        mime_type="image/png",
        path=corrupt_file,
    )

    expect_failure = pytest.raises(
        HomeAssistantError,
        match="Failed to process image attachment",
    )
    with expect_failure:
        await async_prepare_image_generation_attachments(hass, [corrupt_attachment])
async def test_generate_data_returns_text(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test a task without a structure returns the assistant text unchanged."""

    async def fake_handle(chat_type, log, task_name, structure):
        """Stand in for the chat log handler and append a fixed reply."""
        assert chat_type == "ai_task"
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content="Hello from the cloud",
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = fake_handle

    chat_log = conversation.ChatLog(hass, "conversation-id")
    user_message = conversation.UserContent(content="Tell me something")
    chat_log.async_add_user_content(user_message)
    task = ai_task.GenDataTask(name="Task", instructions="Say hi")

    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    assert result.conversation_id == "conversation-id"
    assert result.data == "Hello from the cloud"
async def test_generate_data_returns_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test a task with a structure returns the JSON reply as parsed data."""

    async def fake_handle(chat_type, log, task_name, structure):
        """Stand in for the chat log handler and append a JSON reply."""
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content='{"names": ["A", "B"]}',
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = fake_handle

    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))

    names_schema = vol.Schema({vol.Required("names"): [str]})
    task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=names_schema,
    )

    result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)

    assert result.data == {"names": ["A", "B"]}
async def test_generate_data_invalid_json(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_handle_chat_log: AsyncMock,
) -> None:
    """Test a non-JSON reply to a structured task raises HomeAssistantError."""

    async def fake_handle(chat_type, log, task_name, structure):
        """Stand in for the chat log handler and append a non-JSON reply."""
        reply = conversation.AssistantContent(
            agent_id=mock_cloud_ai_task_entity.entity_id or "",
            content="not-json",
        )
        log.async_add_assistant_content_without_tools(reply)

    mock_handle_chat_log.side_effect = fake_handle

    chat_log = conversation.ChatLog(hass, "conversation-id")
    chat_log.async_add_user_content(conversation.UserContent(content="List names"))

    names_schema = vol.Schema({vol.Required("names"): [str]})
    task = ai_task.GenDataTask(
        name="Task",
        instructions="Return JSON",
        structure=names_schema,
    )

    expect_failure = pytest.raises(
        HomeAssistantError, match="Error with OpenAI structured response"
    )
    with expect_failure:
        await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)
async def test_generate_image_no_attachments(
    hass: HomeAssistant, mock_cloud_ai_task_entity: CloudAITaskEntity
) -> None:
    """Test a task without attachments goes through image generation."""
    generate_image = mock_cloud_ai_task_entity._cloud.llm.async_generate_image
    # Canned response handed back by the mocked cloud LLM.
    generate_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
        "model": "mock-image",
        "width": 1024,
        "height": 768,
        "revised_prompt": "Improved prompt",
    }

    chat_log = conversation.ChatLog(hass, "conversation-id")
    task = ai_task.GenImageTask(name="Task", instructions="Draw something")

    result = await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

    assert result.image_data == b"IMG"
    assert result.mime_type == "image/png"
    # The task instructions are forwarded as the prompt.
    generate_image.assert_awaited_once_with(prompt="Draw something")
async def test_generate_image_with_attachments(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    mock_prepare_generation_attachments: AsyncMock,
) -> None:
    """Test a task with attachments goes through image editing.

    Attachment preparation is mocked, so the referenced file never needs to
    exist; the test checks that the prepared attachments are forwarded to the
    cloud LLM's edit call and that the edited image is returned.
    """
    mock_cloud_ai_task_entity._cloud.llm.async_edit_image.return_value = {
        "mime_type": "image/png",
        "image_data": b"IMG",
    }
    task = ai_task.GenImageTask(
        name="Task",
        instructions="Edit this",
        attachments=[
            conversation.Attachment(
                media_content_id="media-source://media/snapshot.png",
                mime_type="image/png",
                # hass.config.path() returns a str; wrap it so the attachment
                # carries a Path like the attachments in the other tests.
                path=Path(hass.config.path("snapshot.png")),
            )
        ],
    )
    chat_log = conversation.ChatLog(hass, "conversation-id")
    prepared_attachments = [
        LLMImageAttachment(filename="snapshot.png", mime_type="image/png", data=b"IMG")
    ]
    mock_prepare_generation_attachments.return_value = prepared_attachments

    result = await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)

    # The image returned by the edit call is surfaced in the task result.
    assert result.image_data == b"IMG"
    assert result.mime_type == "image/png"
    mock_cloud_ai_task_entity._cloud.llm.async_edit_image.assert_awaited_once_with(
        prompt="Edit this",
        attachments=prepared_attachments,
    )
@pytest.mark.parametrize(
    ("err", "expected_exception", "message"),
    [
        # Each case: error raised by the cloud LLM, exception type expected
        # from the entity, and a pattern the exception message must match.
        (
            LLMAuthenticationError("auth"),
            HomeAssistantError,
            "Cloud LLM authentication failed",
        ),
        (
            LLMRateLimitError("limit"),
            HomeAssistantError,
            "Cloud LLM is rate limited",
        ),
        (
            LLMResponseError("bad response"),
            HomeAssistantError,
            "bad response",
        ),
        (
            LLMServiceError("service"),
            HomeAssistantError,
            "Error talking to Cloud LLM",
        ),
        (
            LLMError("generic"),
            HomeAssistantError,
            "generic",
        ),
    ],
)
async def test_generate_image_error_handling(
    hass: HomeAssistant,
    mock_cloud_ai_task_entity: CloudAITaskEntity,
    err: Exception,
    expected_exception: type[Exception],
    message: str,
) -> None:
    """Test cloud LLM errors during image generation surface as expected."""
    llm = mock_cloud_ai_task_entity._cloud.llm
    # Make the mocked generation call raise the parametrized error.
    llm.async_generate_image.side_effect = err

    chat_log = conversation.ChatLog(hass, "conversation-id")
    task = ai_task.GenImageTask(name="Task", instructions="Draw something")

    expect_failure = pytest.raises(expected_exception, match=message)
    with expect_failure:
        await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)