1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Add script llm tool (#118936)

* Add script llm tool

* Add tests

* More tests

* more test

* more test

* Add area and floor resolving

* coverage

* coverage

* fix ColorTempSelector

* fix mypy

* fix mypy

* add script reload test

* Cache script tool parameters

* Make custom_serializer a part of api

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Denis Shulyaka
2024-06-25 18:43:26 +03:00
committed by GitHub
parent 77fea8a73e
commit 2386ed3830
14 changed files with 639 additions and 55 deletions

View File

@@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.script.config import ScriptConfig
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
@@ -18,6 +19,7 @@ from homeassistant.helpers import (
floor_registry as fr,
intent,
llm,
selector,
)
from homeassistant.setup import async_setup_component
from homeassistant.util import yaml
@@ -564,11 +566,6 @@ async def test_assist_api_prompt(
"names": "Unnamed Device",
"state": "unavailable",
},
"script.test_script": {
"description": "This is a test script",
"names": "test_script",
"state": "off",
},
}
exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n"
@@ -634,3 +631,323 @@ async def test_assist_api_prompt(
{area_prompt}
{exposed_entities_prompt}"""
)
async def test_script_tool(
hass: HomeAssistant,
area_registry: ar.AreaRegistry,
floor_registry: fr.FloorRegistry,
) -> None:
"""Test ScriptTool for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
context=context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
# Create a script with a unique ID
assert await async_setup_component(
hass,
"script",
{
"script": {
"test_script": {
"description": "This is a test script",
"sequence": [],
"fields": {
"beer": {"description": "Number of beers", "required": True},
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
"where": {"selector": {"area": {}}},
"area_list": {"selector": {"area": {"multiple": True}}},
"floor": {"selector": {"floor": {}}},
"floor_list": {"selector": {"floor": {"multiple": True}}},
"extra_field": {"selector": {"area": {}}},
},
},
"unexposed_script": {
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
area = area_registry.async_create("Living room")
floor = floor_registry.async_create("2")
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_script"
assert tool.description == "This is a test script"
schema = {
vol.Required("beer", description="Number of beers"): cv.string,
vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}),
vol.Optional("where"): selector.AreaSelector(),
vol.Optional("area_list"): selector.AreaSelector({"multiple": True}),
vol.Optional("floor"): selector.FloorSelector(),
vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}),
vol.Optional("extra_field"): selector.AreaSelector(),
}
assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
"test_script": ("This is a test script", vol.Schema(schema))
}
tool_input = llm.ToolInput(
tool_name="test_script",
tool_args={
"beer": "3",
"wine": 0,
"where": "Living room",
"area_list": ["Living room"],
"floor": "2",
"floor_list": ["2"],
},
)
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with(
"script",
"turn_on",
{
"entity_id": "script.test_script",
"variables": {
"beer": "3",
"wine": 0,
"where": area.id,
"area_list": [area.id],
"floor": floor.floor_id,
"floor_list": [floor.floor_id],
},
},
context=context,
)
assert response == {"success": True}
# Test reload script with new parameters
config = {
"script": {
"test_script": ScriptConfig(
{
"description": "This is a new test script",
"sequence": [],
"mode": "single",
"max": 2,
"max_exceeded": "WARNING",
"trace": {},
"fields": {
"beer": {"description": "Number of beers", "required": True},
},
}
)
}
}
with patch(
"homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload",
return_value=config,
):
await hass.services.async_call("script", "reload", blocking=True)
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_script"
assert tool.description == "This is a new test script"
schema = {vol.Required("beer", description="Number of beers"): cv.string}
assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
"test_script": ("This is a new test script", vol.Schema(schema))
}
async def test_selector_serializer(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test serialization of Selectors in Open API format."""
api = await llm.async_get_api(hass, "assist", llm_context)
selector_serializer = api.custom_serializer
assert selector_serializer(selector.ActionSelector()) == {"type": "string"}
assert selector_serializer(selector.AddonSelector()) == {"type": "string"}
assert selector_serializer(selector.AreaSelector()) == {"type": "string"}
assert selector_serializer(selector.AreaSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"}
assert selector_serializer(
selector.AttributeSelector({"entity_id": "sensor.test"})
) == {"type": "string"}
assert selector_serializer(selector.BackupLocationSelector()) == {
"type": "string",
"pattern": "^(?:\\/backup|\\w+)$",
}
assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"}
assert selector_serializer(selector.ColorRGBSelector()) == {
"type": "array",
"items": {"type": "number"},
"maxItems": 3,
"minItems": 3,
"format": "RGB",
}
assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"}
assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == {
"type": "number",
"minimum": 0,
"maximum": 1000,
}
assert selector_serializer(
selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000})
) == {"type": "number", "minimum": 100, "maximum": 1000}
assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"}
assert selector_serializer(selector.ConstantSelector({"value": "test"})) == {
"enum": ["test"]
}
assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]}
assert selector_serializer(selector.ConstantSelector({"value": True})) == {
"enum": [True]
}
assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == {
"type": "string"
}
assert selector_serializer(selector.ConversationAgentSelector()) == {
"type": "string"
}
assert selector_serializer(selector.CountrySelector()) == {
"type": "string",
"format": "ISO 3166-1 alpha-2",
}
assert selector_serializer(
selector.CountrySelector({"countries": ["GB", "FR"]})
) == {"type": "string", "enum": ["GB", "FR"]}
assert selector_serializer(selector.DateSelector()) == {
"type": "string",
"format": "date",
}
assert selector_serializer(selector.DateTimeSelector()) == {
"type": "string",
"format": "date-time",
}
assert selector_serializer(selector.DeviceSelector()) == {"type": "string"}
assert selector_serializer(selector.DeviceSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.EntitySelector()) == {
"type": "string",
"format": "entity_id",
}
assert selector_serializer(selector.EntitySelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string", "format": "entity_id"},
}
assert selector_serializer(selector.FloorSelector()) == {"type": "string"}
assert selector_serializer(selector.FloorSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.IconSelector()) == {"type": "string"}
assert selector_serializer(selector.LabelSelector()) == {"type": "string"}
assert selector_serializer(selector.LabelSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.LanguageSelector()) == {
"type": "string",
"format": "RFC 5646",
}
assert selector_serializer(
selector.LanguageSelector({"languages": ["en", "fr"]})
) == {"type": "string", "enum": ["en", "fr"]}
assert selector_serializer(selector.LocationSelector()) == {
"type": "object",
"properties": {
"latitude": {"type": "number"},
"longitude": {"type": "number"},
"radius": {"type": "number"},
},
"required": ["latitude", "longitude"],
}
assert selector_serializer(selector.MediaSelector()) == {
"type": "object",
"properties": {
"entity_id": {"type": "string"},
"media_content_id": {"type": "string"},
"media_content_type": {"type": "string"},
"metadata": {"type": "object", "additionalProperties": True},
},
"required": ["entity_id", "media_content_id", "media_content_type"],
}
assert selector_serializer(selector.NumberSelector({"mode": "box"})) == {
"type": "number"
}
assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == {
"type": "number",
"minimum": 30,
"maximum": 100,
}
assert selector_serializer(selector.ObjectSelector()) == {"type": "object"}
assert selector_serializer(
selector.SelectSelector(
{
"options": [
{"value": "A", "label": "Letter A"},
{"value": "B", "label": "Letter B"},
{"value": "C", "label": "Letter C"},
]
}
)
) == {"type": "string", "enum": ["A", "B", "C"]}
assert selector_serializer(
selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True})
) == {
"type": "array",
"items": {"type": "string", "enum": ["A", "B", "C"]},
"uniqueItems": True,
}
assert selector_serializer(
selector.StateSelector({"entity_id": "sensor.test"})
) == {"type": "string"}
assert selector_serializer(selector.TemplateSelector()) == {
"type": "string",
"format": "jinja2",
}
assert selector_serializer(selector.TextSelector()) == {"type": "string"}
assert selector_serializer(selector.TextSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.ThemeSelector()) == {"type": "string"}
assert selector_serializer(selector.TimeSelector()) == {
"type": "string",
"format": "time",
}
assert selector_serializer(selector.TriggerSelector()) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
"type": "string"
}