mirror of
https://github.com/home-assistant/core.git
synced 2026-04-22 17:59:02 +01:00
Add HomeAssistant Cloud ai_task (#157015)
This commit is contained in:
343
tests/components/cloud/test_ai_task.py
Normal file
343
tests/components/cloud/test_ai_task.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Tests for the Home Assistant Cloud AI Task entity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from hass_nabucasa.llm import (
|
||||
LLMAuthenticationError,
|
||||
LLMError,
|
||||
LLMImageAttachment,
|
||||
LLMRateLimitError,
|
||||
LLMResponseError,
|
||||
LLMServiceError,
|
||||
)
|
||||
from PIL import Image
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.components.cloud.ai_task import (
|
||||
CloudLLMTaskEntity,
|
||||
async_prepare_image_generation_attachments,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, HomeAssistantError
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cloud_ai_task_entity(hass: HomeAssistant) -> CloudLLMTaskEntity:
|
||||
"""Return a CloudLLMTaskEntity with a mocked cloud LLM."""
|
||||
cloud = MagicMock()
|
||||
cloud.llm = MagicMock(
|
||||
async_generate_image=AsyncMock(),
|
||||
async_edit_image=AsyncMock(),
|
||||
)
|
||||
cloud.is_logged_in = True
|
||||
cloud.valid_subscription = True
|
||||
entry = MockConfigEntry(domain="cloud")
|
||||
entry.add_to_hass(hass)
|
||||
entity = CloudLLMTaskEntity(cloud, entry)
|
||||
entity.entity_id = "ai_task.cloud_ai_task"
|
||||
entity.hass = hass
|
||||
return entity
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_handle_chat_log")
|
||||
def mock_handle_chat_log_fixture() -> AsyncMock:
|
||||
"""Patch the chat log handler."""
|
||||
with patch(
|
||||
"homeassistant.components.cloud.ai_task.CloudLLMTaskEntity._async_handle_chat_log",
|
||||
AsyncMock(),
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_prepare_generation_attachments")
|
||||
def mock_prepare_generation_attachments_fixture() -> AsyncMock:
|
||||
"""Patch image generation attachment preparation."""
|
||||
with patch(
|
||||
"homeassistant.components.cloud.ai_task.async_prepare_image_generation_attachments",
|
||||
AsyncMock(),
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments(
|
||||
hass: HomeAssistant, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test preparing attachments for image generation."""
|
||||
image_path = tmp_path / "snapshot.jpg"
|
||||
Image.new("RGB", (2, 2), "red").save(image_path, "JPEG")
|
||||
|
||||
attachments = [
|
||||
conversation.Attachment(
|
||||
media_content_id="media-source://media/snapshot.jpg",
|
||||
mime_type="image/jpeg",
|
||||
path=image_path,
|
||||
)
|
||||
]
|
||||
|
||||
result = await async_prepare_image_generation_attachments(hass, attachments)
|
||||
|
||||
assert len(result) == 1
|
||||
attachment = result[0]
|
||||
assert attachment["filename"] == "snapshot.jpg"
|
||||
assert attachment["mime_type"] == "image/png"
|
||||
assert attachment["data"].startswith(b"\x89PNG")
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments_only_images(
|
||||
hass: HomeAssistant, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test non image attachments are rejected."""
|
||||
doc_path = tmp_path / "context.txt"
|
||||
doc_path.write_text("context")
|
||||
|
||||
attachments = [
|
||||
conversation.Attachment(
|
||||
media_content_id="media-source://media/context.txt",
|
||||
mime_type="text/plain",
|
||||
path=doc_path,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="Only image attachments are supported for image generation",
|
||||
):
|
||||
await async_prepare_image_generation_attachments(hass, attachments)
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments_missing_file(
|
||||
hass: HomeAssistant, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test missing attachments raise a helpful error."""
|
||||
missing_path = tmp_path / "missing.png"
|
||||
|
||||
attachments = [
|
||||
conversation.Attachment(
|
||||
media_content_id="media-source://media/missing.png",
|
||||
mime_type="image/png",
|
||||
path=missing_path,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(HomeAssistantError, match="`.*missing.png` does not exist"):
|
||||
await async_prepare_image_generation_attachments(hass, attachments)
|
||||
|
||||
|
||||
async def test_prepare_image_generation_attachments_processing_error(
|
||||
hass: HomeAssistant, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test invalid image data raises a processing error."""
|
||||
broken_path = tmp_path / "broken.png"
|
||||
broken_path.write_bytes(b"not-an-image")
|
||||
|
||||
attachments = [
|
||||
conversation.Attachment(
|
||||
media_content_id="media-source://media/broken.png",
|
||||
mime_type="image/png",
|
||||
path=broken_path,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="Failed to process image attachment",
|
||||
):
|
||||
await async_prepare_image_generation_attachments(hass, attachments)
|
||||
|
||||
|
||||
async def test_generate_data_returns_text(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_handle_chat_log: AsyncMock,
|
||||
) -> None:
|
||||
"""Test generating plain text data."""
|
||||
chat_log = conversation.ChatLog(hass, "conversation-id")
|
||||
chat_log.async_add_user_content(
|
||||
conversation.UserContent(content="Tell me something")
|
||||
)
|
||||
task = ai_task.GenDataTask(name="Task", instructions="Say hi")
|
||||
|
||||
async def fake_handle(chat_type, log, task_name, structure):
|
||||
"""Inject assistant output."""
|
||||
assert chat_type == "ai_task"
|
||||
log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id=mock_cloud_ai_task_entity.entity_id or "",
|
||||
content="Hello from the cloud",
|
||||
)
|
||||
)
|
||||
|
||||
mock_handle_chat_log.side_effect = fake_handle
|
||||
result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)
|
||||
|
||||
assert result.conversation_id == "conversation-id"
|
||||
assert result.data == "Hello from the cloud"
|
||||
|
||||
|
||||
async def test_generate_data_returns_json(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_handle_chat_log: AsyncMock,
|
||||
) -> None:
|
||||
"""Test generating structured data."""
|
||||
chat_log = conversation.ChatLog(hass, "conversation-id")
|
||||
chat_log.async_add_user_content(conversation.UserContent(content="List names"))
|
||||
task = ai_task.GenDataTask(
|
||||
name="Task",
|
||||
instructions="Return JSON",
|
||||
structure=vol.Schema({vol.Required("names"): [str]}),
|
||||
)
|
||||
|
||||
async def fake_handle(chat_type, log, task_name, structure):
|
||||
log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id=mock_cloud_ai_task_entity.entity_id or "",
|
||||
content='{"names": ["A", "B"]}',
|
||||
)
|
||||
)
|
||||
|
||||
mock_handle_chat_log.side_effect = fake_handle
|
||||
result = await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)
|
||||
|
||||
assert result.data == {"names": ["A", "B"]}
|
||||
|
||||
|
||||
async def test_generate_data_invalid_json(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_handle_chat_log: AsyncMock,
|
||||
) -> None:
|
||||
"""Test invalid JSON responses raise an error."""
|
||||
chat_log = conversation.ChatLog(hass, "conversation-id")
|
||||
chat_log.async_add_user_content(conversation.UserContent(content="List names"))
|
||||
task = ai_task.GenDataTask(
|
||||
name="Task",
|
||||
instructions="Return JSON",
|
||||
structure=vol.Schema({vol.Required("names"): [str]}),
|
||||
)
|
||||
|
||||
async def fake_handle(chat_type, log, task_name, structure):
|
||||
log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id=mock_cloud_ai_task_entity.entity_id or "",
|
||||
content="not-json",
|
||||
)
|
||||
)
|
||||
|
||||
mock_handle_chat_log.side_effect = fake_handle
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="Error with OpenAI structured response"
|
||||
):
|
||||
await mock_cloud_ai_task_entity._async_generate_data(task, chat_log)
|
||||
|
||||
|
||||
async def test_generate_image_no_attachments(
|
||||
hass: HomeAssistant, mock_cloud_ai_task_entity: CloudLLMTaskEntity
|
||||
) -> None:
|
||||
"""Test generating an image without attachments."""
|
||||
mock_cloud_ai_task_entity._cloud.llm.async_generate_image.return_value = {
|
||||
"mime_type": "image/png",
|
||||
"image_data": b"IMG",
|
||||
"model": "mock-image",
|
||||
"width": 1024,
|
||||
"height": 768,
|
||||
"revised_prompt": "Improved prompt",
|
||||
}
|
||||
task = ai_task.GenImageTask(name="Task", instructions="Draw something")
|
||||
chat_log = conversation.ChatLog(hass, "conversation-id")
|
||||
|
||||
result = await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)
|
||||
|
||||
assert result.image_data == b"IMG"
|
||||
assert result.mime_type == "image/png"
|
||||
mock_cloud_ai_task_entity._cloud.llm.async_generate_image.assert_awaited_once_with(
|
||||
prompt="Draw something"
|
||||
)
|
||||
|
||||
|
||||
async def test_generate_image_with_attachments(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
mock_prepare_generation_attachments: AsyncMock,
|
||||
) -> None:
|
||||
"""Test generating an edited image when attachments are provided."""
|
||||
mock_cloud_ai_task_entity._cloud.llm.async_edit_image.return_value = {
|
||||
"mime_type": "image/png",
|
||||
"image_data": b"IMG",
|
||||
}
|
||||
task = ai_task.GenImageTask(
|
||||
name="Task",
|
||||
instructions="Edit this",
|
||||
attachments=[
|
||||
conversation.Attachment(
|
||||
media_content_id="media-source://media/snapshot.png",
|
||||
mime_type="image/png",
|
||||
path=hass.config.path("snapshot.png"),
|
||||
)
|
||||
],
|
||||
)
|
||||
chat_log = conversation.ChatLog(hass, "conversation-id")
|
||||
prepared_attachments = [
|
||||
LLMImageAttachment(filename="snapshot.png", mime_type="image/png", data=b"IMG")
|
||||
]
|
||||
|
||||
mock_prepare_generation_attachments.return_value = prepared_attachments
|
||||
await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)
|
||||
|
||||
mock_cloud_ai_task_entity._cloud.llm.async_edit_image.assert_awaited_once_with(
|
||||
prompt="Edit this",
|
||||
attachments=prepared_attachments,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("err", "expected_exception", "message"),
|
||||
[
|
||||
(
|
||||
LLMAuthenticationError("auth"),
|
||||
ConfigEntryAuthFailed,
|
||||
"Cloud LLM authentication failed",
|
||||
),
|
||||
(
|
||||
LLMRateLimitError("limit"),
|
||||
HomeAssistantError,
|
||||
"Cloud LLM is rate limited",
|
||||
),
|
||||
(
|
||||
LLMResponseError("bad response"),
|
||||
HomeAssistantError,
|
||||
"bad response",
|
||||
),
|
||||
(
|
||||
LLMServiceError("service"),
|
||||
HomeAssistantError,
|
||||
"Error talking to Cloud LLM",
|
||||
),
|
||||
(
|
||||
LLMError("generic"),
|
||||
HomeAssistantError,
|
||||
"generic",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_generate_image_error_handling(
|
||||
hass: HomeAssistant,
|
||||
mock_cloud_ai_task_entity: CloudLLMTaskEntity,
|
||||
err: Exception,
|
||||
expected_exception: type[Exception],
|
||||
message: str,
|
||||
) -> None:
|
||||
"""Test image generation error handling."""
|
||||
mock_cloud_ai_task_entity._cloud.llm.async_generate_image.side_effect = err
|
||||
task = ai_task.GenImageTask(name="Task", instructions="Draw something")
|
||||
chat_log = conversation.ChatLog(hass, "conversation-id")
|
||||
|
||||
with pytest.raises(expected_exception, match=message):
|
||||
await mock_cloud_ai_task_entity._async_generate_image(task, chat_log)
|
||||
Reference in New Issue
Block a user