1
0
mirror of https://github.com/home-assistant/core.git synced 2026-06-04 22:53:54 +01:00
Files
core/homeassistant/components/conversation/http.py
T
2026-05-14 16:30:32 -04:00

385 lines
12 KiB
Python

"""HTTP endpoints for conversation integration."""
from dataclasses import asdict
from typing import Any
from aiohttp import web
import voluptuous as vol
from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.chat_session import async_get_chat_session
from homeassistant.util import language as language_util
from .agent_manager import (
agent_id_validator,
async_converse,
async_get_agent,
get_agent_manager,
)
from .chat_log import DATA_CHAT_LOGS, async_get_chat_log, async_subscribe_chat_logs
from .const import DATA_COMPONENT, ChatLogEventType
from .entity import ConversationEntity
from .models import ConversationInput
@callback
def async_setup(hass: HomeAssistant) -> None:
    """Register the conversation HTTP view and all websocket commands."""
    hass.http.register_view(ConversationProcessView())

    # Commands are registered in the order listed here.
    ws_handlers = (
        websocket_process,
        websocket_prepare,
        websocket_list_agents,
        websocket_list_sentences,
        websocket_hass_agent_debug,
        websocket_hass_agent_language_scores,
        websocket_subscribe_chat_log,
        websocket_subscribe_chat_log_index,
    )
    for ws_handler in ws_handlers:
        websocket_api.async_register_command(hass, ws_handler)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/process",
        vol.Required("text"): str,
        vol.Optional("conversation_id"): vol.Any(str, None),
        vol.Optional("language"): str,
        vol.Optional("agent_id"): agent_id_validator,
        vol.Optional("device_id"): vol.Any(str, None),
        vol.Optional("satellite_id"): vol.Any(str, None),
    }
)
@websocket_api.async_response
async def websocket_process(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Pass the submitted text to ``async_converse`` and reply with its result.

    Optional message fields that were not supplied are forwarded as ``None``.
    """
    text = msg["text"]
    conversation_id = msg.get("conversation_id")
    context = connection.context(msg)

    conversation_result = await async_converse(
        hass=hass,
        text=text,
        conversation_id=conversation_id,
        context=context,
        language=msg.get("language"),
        agent_id=msg.get("agent_id"),
        device_id=msg.get("device_id"),
        satellite_id=msg.get("satellite_id"),
    )

    payload = conversation_result.as_dict()
    connection.send_result(msg["id"], payload)
@websocket_api.websocket_command(
    {
        # Use the explicit marker, matching every other command schema here.
        vol.Required("type"): "conversation/prepare",
        vol.Optional("language"): str,
        vol.Optional("agent_id"): agent_id_validator,
    }
)
@websocket_api.async_response
async def websocket_prepare(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Ask a conversation agent to prepare for a language.

    The agent is looked up from the optional ``agent_id`` field and its
    ``async_prepare`` is awaited with the optional ``language`` field
    (``None`` when omitted). Replies with a not-found error when the lookup
    returns no agent, otherwise with an empty result.
    """
    msg_id = msg["id"]

    agent = async_get_agent(hass, msg.get("agent_id"))
    if agent is None:
        connection.send_error(msg_id, websocket_api.ERR_NOT_FOUND, "Agent not found")
        return

    await agent.async_prepare(msg.get("language"))
    connection.send_result(msg_id)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/agent/list",
        vol.Optional("language"): str,
        vol.Optional("country"): str,
    }
)
@websocket_api.async_response
async def websocket_list_agents(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """List conversation agents and, optionally, if they support a given language.

    The reply is ``{"agents": [...]}`` where each entry has ``id``, ``name``
    and ``supported_languages``. Conversation entities are listed first,
    followed by agents known to the agent manager that are not conversation
    entities. When ``language`` is supplied, an agent's languages are narrowed
    with ``language_util.matches`` unless the agent reports ``MATCH_ALL``.
    """
    country = msg.get("country")
    language = msg.get("language")

    def _filter_languages(supported_languages: Any) -> Any:
        """Narrow an agent's languages to those matching the request."""
        if language and supported_languages != MATCH_ALL:
            return language_util.matches(language, supported_languages, country)
        return supported_languages

    agents: list[dict[str, Any]] = []

    for entity in hass.data[DATA_COMPONENT].entities:
        # Prefer the name from the state machine; fall back to the entity ID
        # when the entity has no state.
        state = hass.states.get(entity.entity_id)
        agents.append(
            {
                "id": entity.entity_id,
                "name": state.name if state else entity.entity_id,
                "supported_languages": _filter_languages(entity.supported_languages),
            }
        )

    manager = get_agent_manager(hass)
    for agent_info in manager.async_get_agent_info():
        agent = manager.async_get_agent(agent_info.id)
        # The ID was just reported by the manager, so the lookup should succeed.
        assert agent is not None

        # Conversation entities were already added by the loop above.
        if isinstance(agent, ConversationEntity):
            continue

        agents.append(
            {
                "id": agent_info.id,
                "name": agent_info.name,
                "supported_languages": _filter_languages(agent.supported_languages),
            }
        )

    connection.send_result(msg["id"], {"agents": agents})
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/sentences/list",
    }
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_list_sentences(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """Reply with the trigger sentences held by the agent manager."""
    trigger_sentences = get_agent_manager(hass).trigger_sentences
    reply = {"trigger_sentences": trigger_sentences}
    connection.send_result(msg["id"], reply)
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/agent/homeassistant/debug",
        vol.Required("sentences"): [str],
        vol.Optional("language"): str,
        vol.Optional("device_id"): vol.Any(str, None),
    }
)
@websocket_api.async_response
async def websocket_hass_agent_debug(
    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
    """Run each sentence through the default agent's debug recognizer.

    Replies with ``{"results": [...]}``, one entry per input sentence.
    """
    default_agent = get_agent_manager(hass).default_agent
    assert default_agent is not None

    # Sentences are awaited one at a time, so results keep the input order.
    recognized: list[dict[str, Any] | None] = [
        await default_agent.async_debug_recognize(
            ConversationInput(
                text=sentence,
                context=connection.context(msg),
                conversation_id=None,
                device_id=msg.get("device_id"),
                satellite_id=None,
                language=msg.get("language", hass.config.language),
                agent_id=default_agent.entity_id,
            )
        )
        for sentence in msg["sentences"]
    ]

    connection.send_result(msg["id"], {"results": recognized})
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/agent/homeassistant/language_scores",
        vol.Optional("language"): str,
        vol.Optional("country"): str,
    }
)
@websocket_api.async_response
async def websocket_hass_agent_language_scores(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Reply with the default agent's per-language scores.

    The reply also names a preferred language: the first scored language
    matching the requested (or configured) language and country, or the
    requested language itself when nothing matches.
    """
    default_agent = get_agent_manager(hass).default_agent
    assert default_agent is not None

    # Fall back to the instance configuration for omitted fields.
    language = msg.get("language", hass.config.language)
    country = msg.get("country", hass.config.country)

    scores = await default_agent.async_get_language_scores()

    candidates = language_util.matches(language, scores.keys(), country=country)
    if candidates:
        preferred_lang = candidates[0]
    else:
        preferred_lang = language

    languages = {}
    for lang_key, lang_scores in scores.items():
        languages[lang_key] = asdict(lang_scores)

    connection.send_result(
        msg["id"],
        {
            "languages": languages,
            "preferred_language": preferred_lang,
        },
    )
class ConversationProcessView(http.HomeAssistantView):
    """View to process text."""

    url = "/api/conversation/process"
    name = "api:conversation:process"

    @RequestDataValidator(
        vol.Schema(
            {
                vol.Required("text"): str,
                vol.Optional("conversation_id"): str,
                vol.Optional("language"): str,
                vol.Optional("agent_id"): agent_id_validator,
                vol.Optional("device_id"): vol.Any(str, None),
                vol.Optional("satellite_id"): vol.Any(str, None),
            }
        )
    )
    async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
        """Process the posted text and return the result as JSON.

        ``data`` is the request body validated against the schema above. Its
        values are not all strings (``device_id`` and ``satellite_id`` may be
        ``None``), hence the ``Any`` value type. Optional fields that were
        not supplied are forwarded to ``async_converse`` as ``None``.
        """
        hass = request.app[http.KEY_HASS]

        result = await async_converse(
            hass=hass,
            text=data["text"],
            conversation_id=data.get("conversation_id"),
            context=self.context(request),
            language=data.get("language"),
            agent_id=data.get("agent_id"),
            device_id=data.get("device_id"),
            satellite_id=data.get("satellite_id"),
        )

        return self.json(result.as_dict())
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/chat_log/subscribe",
        vol.Required("conversation_id"): str,
    }
)
@websocket_api.require_admin
def websocket_subscribe_chat_log(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Subscribe to the events of a single chat log.

    Replies with a not-found error when no chat log exists for the requested
    ``conversation_id``. Otherwise the subscription is registered, the command
    is acknowledged, and an ``INITIAL_STATE`` event carrying the current chat
    log is sent. Later events for that conversation are forwarded until a
    ``DELETED`` event arrives, which also ends the subscription.
    """
    msg_id = msg["id"]
    subscribed_conversation = msg["conversation_id"]

    chat_logs = hass.data.get(DATA_CHAT_LOGS)
    if not chat_logs or subscribed_conversation not in chat_logs:
        connection.send_error(
            msg_id,
            websocket_api.ERR_NOT_FOUND,
            "Conversation chat log not found",
        )
        return

    @callback
    def forward_events(
        conversation_id: str, event_type: ChatLogEventType, data: dict[str, Any]
    ) -> None:
        """Forward chat log events to websocket connection."""
        # The listener receives events for every chat log; keep only ours.
        if conversation_id != subscribed_conversation:
            return

        connection.send_event(
            msg_id,
            {
                "conversation_id": conversation_id,
                "event_type": event_type,
                "data": data,
            },
        )

        if event_type == ChatLogEventType.DELETED:
            # The chat log is gone, so nothing further can arrive for it.
            unsubscribe()
            # pop() rather than del: an already-removed entry must not raise
            # from inside the event callback.
            connection.subscriptions.pop(msg_id, None)

    unsubscribe = async_subscribe_chat_logs(hass, forward_events)
    connection.subscriptions[msg_id] = unsubscribe
    connection.send_result(msg_id)

    # Send the current contents right after the acknowledgement.
    with (
        async_get_chat_session(hass, subscribed_conversation) as session,
        async_get_chat_log(hass, session) as chat_log,
    ):
        connection.send_event(
            msg_id,
            {
                "event_type": ChatLogEventType.INITIAL_STATE,
                "data": chat_log.as_dict(),
            },
        )
@websocket_api.websocket_command(
    {
        vol.Required("type"): "conversation/chat_log/subscribe_index",
    }
)
@websocket_api.require_admin
def websocket_subscribe_chat_log_index(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Subscribe to creation and deletion of chat logs.

    The subscription is registered and the command acknowledged. If any chat
    logs currently exist, an ``INITIAL_STATE`` event listing all of them is
    sent. Afterwards only ``CREATED`` and ``DELETED`` events are forwarded,
    for any conversation.
    """
    msg_id = msg["id"]

    @callback
    def forward_events(
        conversation_id: str, event_type: ChatLogEventType, data: dict[str, Any]
    ) -> None:
        """Forward chat log events to websocket connection."""
        # The index only tracks which chat logs exist, not their contents.
        if event_type not in (ChatLogEventType.CREATED, ChatLogEventType.DELETED):
            return

        connection.send_event(
            msg_id,
            {
                "conversation_id": conversation_id,
                "event_type": event_type,
                "data": data,
            },
        )

    unsubscribe = async_subscribe_chat_logs(hass, forward_events)
    connection.subscriptions[msg_id] = unsubscribe
    connection.send_result(msg_id)

    chat_logs = hass.data.get(DATA_CHAT_LOGS)
    if not chat_logs:
        # No chat logs yet: no initial state event is sent.
        return

    connection.send_event(
        msg_id,
        {
            "event_type": ChatLogEventType.INITIAL_STATE,
            "data": [c.as_dict() for c in chat_logs.values()],
        },
    )