1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Add home assistant cloud conversation (#157090)

This commit is contained in:
victorigualada
2025-11-25 20:04:19 +01:00
committed by GitHub
parent d6fb268119
commit 7c2741bd36
6 changed files with 392 additions and 74 deletions

View File

@@ -80,6 +80,7 @@ DEFAULT_MODE = MODE_PROD
PLATFORMS = [
Platform.AI_TASK,
Platform.BINARY_SENSOR,
Platform.CONVERSATION,
Platform.STT,
Platform.TTS,
]

View File

@@ -92,6 +92,7 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task"
CONVERSATION_ENTITY_UNIQUE_ID = "cloud-conversation-agent"
LOGIN_MFA_TIMEOUT = 60

View File

@@ -0,0 +1,75 @@
"""Conversation support for Home Assistant Cloud."""
from __future__ import annotations
from typing import Literal
from hass_nabucasa.llm import LLMError
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import CONVERSATION_ENTITY_UNIQUE_ID, DATA_CLOUD, DOMAIN
from .entity import BaseCloudLLMEntity
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up the Home Assistant Cloud conversation entity."""
    cloud = hass.data[DATA_CLOUD]
    # A valid LLM token is a prerequisite for the agent; when it cannot be
    # obtained the platform is set up without any entity.
    try:
        await cloud.llm.async_ensure_token()
    except LLMError:
        return
    agent = CloudConversationEntity(cloud, config_entry)
    async_add_entities([agent])
class CloudConversationEntity(
    conversation.ConversationEntity,
    BaseCloudLLMEntity,
):
    """Home Assistant Cloud conversation agent."""

    _attr_has_entity_name = True
    _attr_name = "Home Assistant Cloud"
    _attr_translation_key = "cloud_conversation"
    _attr_unique_id = CONVERSATION_ENTITY_UNIQUE_ID
    _attr_supported_features = conversation.ConversationEntityFeature.CONTROL

    @property
    def available(self) -> bool:
        """Return if the entity is available.

        The agent only works while the account is logged in and holds a
        valid cloud subscription.
        """
        return self._cloud.is_logged_in and self._cloud.valid_subscription

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        # The cloud LLM is not restricted to a fixed language set.
        return MATCH_ALL

    async def _async_handle_message(
        self,
        user_input: conversation.ConversationInput,
        chat_log: conversation.ChatLog,
    ) -> conversation.ConversationResult:
        """Process a user input."""
        try:
            await chat_log.async_provide_llm_data(
                user_input.as_llm_context(DOMAIN),
                llm.LLM_API_ASSIST,
                None,
                user_input.extra_system_prompt,
            )
        except conversation.ConverseError as err:
            # Surface the failure as a normal conversation result so the
            # caller receives a response instead of an exception.
            return err.as_conversation_result()

        await self._async_handle_chat_log("conversation", chat_log)
        return conversation.async_get_result_from_chat_log(user_input, chat_log)

View File

@@ -1,7 +1,7 @@
"""Helpers for cloud LLM chat handling."""
import base64
from collections.abc import AsyncGenerator, Callable
from collections.abc import AsyncGenerator, Callable, Iterable
from enum import Enum
import json
import logging
@@ -16,13 +16,22 @@ from hass_nabucasa.llm import (
LLMResponseError,
LLMServiceError,
)
from litellm import ResponseFunctionToolCall, ResponsesAPIStreamEvents
from litellm import (
ResponseFunctionToolCall,
ResponseInputParam,
ResponsesAPIStreamEvents,
)
from openai.types.responses import (
FunctionToolParam,
ResponseInputItemParam,
ResponseReasoningItem,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_input_param import (
ImageGenerationCall as ImageGenerationCallParam,
)
from openai.types.responses.response_output_item import ImageGenerationCall
import voluptuous as vol
from voluptuous_openapi import convert
@@ -50,34 +59,97 @@ class ResponseItemType(str, Enum):
IMAGE = "image"
def _convert_content_to_chat_message(
content: conversation.Content,
) -> dict[str, Any] | None:
"""Convert ChatLog content to a responses message."""
if content.role not in ("user", "system", "tool_result", "assistant"):
return None
def _convert_content_to_param(
chat_content: Iterable[conversation.Content],
) -> ResponseInputParam:
"""Convert any native chat message for this agent to the native format."""
messages: ResponseInputParam = []
reasoning_summary: list[str] = []
web_search_calls: dict[str, dict[str, Any]] = {}
text_content = cast(
conversation.SystemContent
| conversation.UserContent
| conversation.AssistantContent,
content,
)
for content in chat_content:
if isinstance(content, conversation.ToolResultContent):
if (
content.tool_name == "web_search_call"
and content.tool_call_id in web_search_calls
):
web_search_call = web_search_calls.pop(content.tool_call_id)
web_search_call["status"] = content.tool_result.get(
"status", "completed"
)
messages.append(cast("ResponseInputItemParam", web_search_call))
else:
messages.append(
{
"type": "function_call_output",
"call_id": content.tool_call_id,
"output": json.dumps(content.tool_result),
}
)
continue
if not text_content.content:
return None
if content.content:
role: Literal["user", "assistant", "system", "developer"] = content.role
if role == "system":
role = "developer"
messages.append(
{"type": "message", "role": role, "content": content.content}
)
content_type = "output_text" if text_content.role == "assistant" else "input_text"
if isinstance(content, conversation.AssistantContent):
if content.tool_calls:
for tool_call in content.tool_calls:
if (
tool_call.external
and tool_call.tool_name == "web_search_call"
and "action" in tool_call.tool_args
):
web_search_calls[tool_call.id] = {
"type": "web_search_call",
"id": tool_call.id,
"action": tool_call.tool_args["action"],
"status": "completed",
}
else:
messages.append(
{
"type": "function_call",
"name": tool_call.tool_name,
"arguments": json.dumps(tool_call.tool_args),
"call_id": tool_call.id,
}
)
return {
"role": text_content.role,
"content": [
{
"type": content_type,
"text": text_content.content,
}
],
}
if content.thinking_content:
reasoning_summary.append(content.thinking_content)
if isinstance(content.native, ResponseReasoningItem):
messages.append(
{
"type": "reasoning",
"id": content.native.id,
"summary": (
[
{
"type": "summary_text",
"text": summary,
}
for summary in reasoning_summary
]
if content.thinking_content
else []
),
"encrypted_content": content.native.encrypted_content,
}
)
reasoning_summary = []
elif isinstance(content.native, ImageGenerationCall):
messages.append(
cast(ImageGenerationCallParam, content.native.to_dict())
)
return messages
def _format_tool(
@@ -381,25 +453,16 @@ class BaseCloudLLMEntity(Entity):
async def _prepare_chat_for_generation(
self,
chat_log: conversation.ChatLog,
messages: ResponseInputParam,
response_format: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Prepare kwargs for Cloud LLM from the chat log."""
messages = [
message
for content in chat_log.content
if (message := _convert_content_to_chat_message(content))
]
if not messages or messages[-1]["role"] != "user":
raise HomeAssistantError("No user prompt found")
last_content = chat_log.content[-1]
last_content: Any = chat_log.content[-1]
if last_content.role == "user" and last_content.attachments:
files = await self._async_prepare_files_for_prompt(last_content.attachments)
user_message = messages[-1]
current_content = user_message.get("content", [])
user_message["content"] = [*(current_content or []), *files]
current_content = last_content.content
last_content = [*(current_content or []), *files]
tools: list[ToolParam] = []
tool_choice: str | None = None
@@ -503,8 +566,11 @@ class BaseCloudLLMEntity(Entity):
},
}
messages = _convert_content_to_param(chat_log.content)
response_kwargs = await self._prepare_chat_for_generation(
chat_log,
messages,
response_format,
)
@@ -518,15 +584,21 @@ class BaseCloudLLMEntity(Entity):
**response_kwargs,
)
async for _ in chat_log.async_add_delta_content_stream(
agent_id=self.entity_id,
stream=_transform_stream(
chat_log,
raw_stream,
True,
),
):
pass
messages.extend(
_convert_content_to_param(
[
content
async for content in chat_log.async_add_delta_content_stream(
self.entity_id,
_transform_stream(
chat_log,
raw_stream,
True,
),
)
]
)
)
except LLMAuthenticationError as err:
raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err

View File

@@ -0,0 +1,167 @@
"""Tests for the Home Assistant Cloud conversation entity."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
from hass_nabucasa.llm import LLMError
import pytest
from homeassistant.components import conversation
from homeassistant.components.cloud.const import DATA_CLOUD, DOMAIN
from homeassistant.components.cloud.conversation import (
CloudConversationEntity,
async_setup_entry,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry
@pytest.fixture
def cloud_conversation_entity(hass: HomeAssistant) -> CloudConversationEntity:
    """Return a CloudConversationEntity attached to hass."""
    # Cloud stub: logged in with an active subscription and an LLM client.
    mock_cloud = MagicMock(
        llm=MagicMock(), is_logged_in=True, valid_subscription=True
    )

    config_entry = MockConfigEntry(domain=DOMAIN)
    config_entry.add_to_hass(hass)

    entity = CloudConversationEntity(mock_cloud, config_entry)
    entity.entity_id = "conversation.home_assistant_cloud"
    entity.hass = hass
    return entity
def test_entity_availability(
    cloud_conversation_entity: CloudConversationEntity,
) -> None:
    """Test that availability mirrors the cloud login/subscription state."""
    mock_cloud = cloud_conversation_entity._cloud

    # Logged in with a valid subscription -> available.
    mock_cloud.is_logged_in = True
    mock_cloud.valid_subscription = True
    assert cloud_conversation_entity.available

    # Logged out -> unavailable.
    mock_cloud.is_logged_in = False
    assert not cloud_conversation_entity.available

    # Logged in but without a subscription -> unavailable.
    mock_cloud.is_logged_in = True
    mock_cloud.valid_subscription = False
    assert not cloud_conversation_entity.available
async def test_async_handle_message(
    hass: HomeAssistant, cloud_conversation_entity: CloudConversationEntity
) -> None:
    """Test that messages are processed through the chat log helper."""
    user_input = conversation.ConversationInput(
        text="apaga test",
        context=Context(),
        conversation_id="conversation-id",
        device_id="device-id",
        satellite_id=None,
        language="es",
        agent_id=cloud_conversation_entity.entity_id or "",
        extra_system_prompt="hazlo",
    )
    chat_log = conversation.ChatLog(hass, user_input.conversation_id)
    chat_log.async_add_user_content(conversation.UserContent(content=user_input.text))
    chat_log.async_provide_llm_data = AsyncMock()

    async def inject_assistant_reply(chat_type, log):
        """Inject assistant output so the result can be built."""
        assert chat_type == "conversation"
        assert log is chat_log
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=cloud_conversation_entity.entity_id or "",
                content="hecho",
            )
        )

    handle_chat_log = AsyncMock(side_effect=inject_assistant_reply)
    with patch.object(
        cloud_conversation_entity, "_async_handle_chat_log", handle_chat_log
    ):
        result = await cloud_conversation_entity._async_handle_message(
            user_input, chat_log
        )

    # LLM data must have been provided with the assist API and the extra
    # system prompt from the input.
    chat_log.async_provide_llm_data.assert_awaited_once_with(
        user_input.as_llm_context(DOMAIN),
        llm.LLM_API_ASSIST,
        None,
        user_input.extra_system_prompt,
    )
    handle_chat_log.assert_awaited_once_with("conversation", chat_log)
    assert result.conversation_id == "conversation-id"
    assert result.response.speech["plain"]["speech"] == "hecho"
async def test_async_handle_message_converse_error(
    hass: HomeAssistant, cloud_conversation_entity: CloudConversationEntity
) -> None:
    """Test that ConverseError short-circuits message handling."""
    user_input = conversation.ConversationInput(
        text="hola",
        context=Context(),
        conversation_id="conversation-id",
        device_id=None,
        satellite_id=None,
        language="es",
        agent_id=cloud_conversation_entity.entity_id or "",
    )
    chat_log = conversation.ChatLog(hass, user_input.conversation_id)

    # Make the LLM-data step raise a ConverseError carrying a ready-made
    # intent response.
    error_response = intent.IntentResponse(language="es")
    converse_error = conversation.ConverseError(
        "failed", user_input.conversation_id or "", error_response
    )
    chat_log.async_provide_llm_data = AsyncMock(side_effect=converse_error)

    with patch.object(
        cloud_conversation_entity, "_async_handle_chat_log", AsyncMock()
    ) as handle_chat_log:
        result = await cloud_conversation_entity._async_handle_message(
            user_input, chat_log
        )

    # The chat-log helper must never run; the error response is returned
    # as-is.
    handle_chat_log.assert_not_called()
    assert result.response is error_response
    assert result.conversation_id == user_input.conversation_id
async def test_async_setup_entry_adds_entity(hass: HomeAssistant) -> None:
    """Test the platform setup adds the conversation entity."""
    mock_cloud = MagicMock()
    mock_cloud.llm = MagicMock(async_ensure_token=AsyncMock())
    mock_cloud.is_logged_in = True
    mock_cloud.valid_subscription = True
    hass.data[DATA_CLOUD] = mock_cloud

    config_entry = MockConfigEntry(domain=DOMAIN)
    config_entry.add_to_hass(hass)
    add_entities = MagicMock()

    await async_setup_entry(hass, config_entry, add_entities)

    mock_cloud.llm.async_ensure_token.assert_awaited_once()
    assert add_entities.call_count == 1
    added_entities = add_entities.call_args[0][0]
    assert isinstance(added_entities[0], CloudConversationEntity)
async def test_async_setup_entry_llm_error(hass: HomeAssistant) -> None:
    """Test entity setup is aborted when ensuring the token fails."""
    mock_cloud = MagicMock()
    mock_cloud.llm = MagicMock(
        async_ensure_token=AsyncMock(side_effect=LLMError("fail"))
    )
    mock_cloud.is_logged_in = True
    mock_cloud.valid_subscription = True
    hass.data[DATA_CLOUD] = mock_cloud

    config_entry = MockConfigEntry(domain=DOMAIN)
    config_entry.add_to_hass(hass)
    add_entities = MagicMock()

    await async_setup_entry(hass, config_entry, add_entities)

    # Token acquisition was attempted but failed, so no entity is created.
    mock_cloud.llm.async_ensure_token.assert_awaited_once()
    add_entities.assert_not_called()

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import base64
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from PIL import Image
@@ -14,6 +13,7 @@ import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.cloud.entity import (
BaseCloudLLMEntity,
_convert_content_to_param,
_format_structured_output,
)
from homeassistant.core import HomeAssistant
@@ -38,6 +38,19 @@ def cloud_entity(hass: HomeAssistant) -> BaseCloudLLMEntity:
return entity
@pytest.fixture
def mock_prepare_files_for_prompt(
    cloud_entity: BaseCloudLLMEntity,
) -> AsyncMock:
    """Patch file preparation helper on the entity."""
    patcher = patch.object(
        cloud_entity, "_async_prepare_files_for_prompt", AsyncMock()
    )
    with patcher as mock:
        yield mock
class DummyTool(llm.Tool):
"""Simple tool used for schema conversion tests."""
@@ -162,12 +175,12 @@ async def test_prepare_chat_for_generation_appends_attachments(
attachment = conversation.Attachment(
media_content_id="media-source://media/doorbell.jpg",
mime_type="image/jpeg",
path=hass.config.path("doorbell.jpg"),
path=Path(hass.config.path("doorbell.jpg")),
)
chat_log.async_add_user_content(
conversation.UserContent(content="Describe the door", attachments=[attachment])
)
chat_log.llm_api = SimpleNamespace(
chat_log.llm_api = MagicMock(
tools=[DummyTool()],
custom_serializer=None,
)
@@ -175,8 +188,11 @@ async def test_prepare_chat_for_generation_appends_attachments(
files = [{"type": "input_image", "image_url": "data://img", "detail": "auto"}]
mock_prepare_files_for_prompt.return_value = files
messages = _convert_content_to_param(chat_log.content)
response = await cloud_entity._prepare_chat_for_generation(
chat_log, response_format={"type": "json"}
chat_log,
messages,
response_format={"type": "json"},
)
assert response["conversation_id"] == "conversation-id"
@@ -185,35 +201,21 @@ async def test_prepare_chat_for_generation_appends_attachments(
assert len(response["tools"]) == 2
assert response["tools"][0]["name"] == "do_something"
assert response["tools"][1]["type"] == "web_search"
user_message = response["messages"][-1]
assert user_message["content"][0] == {
"type": "input_text",
"text": "Describe the door",
}
assert user_message["content"][1:] == files
assert response["messages"] is messages
mock_prepare_files_for_prompt.assert_awaited_once_with([attachment])
async def test_prepare_chat_for_generation_requires_user_prompt(
async def test_prepare_chat_for_generation_passes_messages_through(
hass: HomeAssistant, cloud_entity: BaseCloudLLMEntity
) -> None:
"""Test that we fail fast when there is no user input to process."""
"""Test that prepared messages are forwarded unchanged."""
chat_log = conversation.ChatLog(hass, "conversation-id")
chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent(agent_id="agent", content="Ready")
)
messages = _convert_content_to_param(chat_log.content)
with pytest.raises(HomeAssistantError, match="No user prompt found"):
await cloud_entity._prepare_chat_for_generation(chat_log)
response = await cloud_entity._prepare_chat_for_generation(chat_log, messages)
@pytest.fixture
def mock_prepare_files_for_prompt(
cloud_entity: BaseCloudLLMEntity,
) -> AsyncMock:
"""Patch file preparation helper on the entity."""
with patch.object(
cloud_entity,
"_async_prepare_files_for_prompt",
AsyncMock(),
) as mock:
yield mock
assert response["messages"] == messages
assert response["conversation_id"] == "conversation-id"