mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 21:06:19 +00:00
Add script llm tool (#118936)
* Add script llm tool * Add tests * More tests * more test * more test * Add area and floor resolving * coverage * coverage * fix ColorTempSelector * fix mypy * fix mypy * add script reload test * Cache script tool parameters * Make custom_serializer a part of api --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
@@ -8,6 +8,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.script.config import ScriptConfig
|
||||
from homeassistant.core import Context, HomeAssistant, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
@@ -18,6 +19,7 @@ from homeassistant.helpers import (
|
||||
floor_registry as fr,
|
||||
intent,
|
||||
llm,
|
||||
selector,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import yaml
|
||||
@@ -564,11 +566,6 @@ async def test_assist_api_prompt(
|
||||
"names": "Unnamed Device",
|
||||
"state": "unavailable",
|
||||
},
|
||||
"script.test_script": {
|
||||
"description": "This is a test script",
|
||||
"names": "test_script",
|
||||
"state": "off",
|
||||
},
|
||||
}
|
||||
exposed_entities_prompt = (
|
||||
"An overview of the areas and the devices in this smart home:\n"
|
||||
@@ -634,3 +631,323 @@ async def test_assist_api_prompt(
|
||||
{area_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
|
||||
async def test_script_tool(
|
||||
hass: HomeAssistant,
|
||||
area_registry: ar.AreaRegistry,
|
||||
floor_registry: fr.FloorRegistry,
|
||||
) -> None:
|
||||
"""Test ScriptTool for the assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
|
||||
# Create a script with a unique ID
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test_script": {
|
||||
"description": "This is a test script",
|
||||
"sequence": [],
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers", "required": True},
|
||||
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
|
||||
"where": {"selector": {"area": {}}},
|
||||
"area_list": {"selector": {"area": {"multiple": True}}},
|
||||
"floor": {"selector": {"floor": {}}},
|
||||
"floor_list": {"selector": {"floor": {"multiple": True}}},
|
||||
"extra_field": {"selector": {"area": {}}},
|
||||
},
|
||||
},
|
||||
"unexposed_script": {
|
||||
"sequence": [],
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "script.test_script", True)
|
||||
|
||||
area = area_registry.async_create("Living room")
|
||||
floor = floor_registry.async_create("2")
|
||||
|
||||
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_script"
|
||||
assert tool.description == "This is a test script"
|
||||
schema = {
|
||||
vol.Required("beer", description="Number of beers"): cv.string,
|
||||
vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}),
|
||||
vol.Optional("where"): selector.AreaSelector(),
|
||||
vol.Optional("area_list"): selector.AreaSelector({"multiple": True}),
|
||||
vol.Optional("floor"): selector.FloorSelector(),
|
||||
vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}),
|
||||
vol.Optional("extra_field"): selector.AreaSelector(),
|
||||
}
|
||||
assert tool.parameters.schema == schema
|
||||
|
||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
||||
"test_script": ("This is a test script", vol.Schema(schema))
|
||||
}
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="test_script",
|
||||
tool_args={
|
||||
"beer": "3",
|
||||
"wine": 0,
|
||||
"where": "Living room",
|
||||
"area_list": ["Living room"],
|
||||
"floor": "2",
|
||||
"floor_list": ["2"],
|
||||
},
|
||||
)
|
||||
|
||||
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
mock_service_call.assert_awaited_once_with(
|
||||
"script",
|
||||
"turn_on",
|
||||
{
|
||||
"entity_id": "script.test_script",
|
||||
"variables": {
|
||||
"beer": "3",
|
||||
"wine": 0,
|
||||
"where": area.id,
|
||||
"area_list": [area.id],
|
||||
"floor": floor.floor_id,
|
||||
"floor_list": [floor.floor_id],
|
||||
},
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
assert response == {"success": True}
|
||||
|
||||
# Test reload script with new parameters
|
||||
config = {
|
||||
"script": {
|
||||
"test_script": ScriptConfig(
|
||||
{
|
||||
"description": "This is a new test script",
|
||||
"sequence": [],
|
||||
"mode": "single",
|
||||
"max": 2,
|
||||
"max_exceeded": "WARNING",
|
||||
"trace": {},
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers", "required": True},
|
||||
},
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload",
|
||||
return_value=config,
|
||||
):
|
||||
await hass.services.async_call("script", "reload", blocking=True)
|
||||
|
||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
|
||||
tool = tools[0]
|
||||
assert tool.name == "test_script"
|
||||
assert tool.description == "This is a new test script"
|
||||
schema = {vol.Required("beer", description="Number of beers"): cv.string}
|
||||
assert tool.parameters.schema == schema
|
||||
|
||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
||||
"test_script": ("This is a new test script", vol.Schema(schema))
|
||||
}
|
||||
|
||||
|
||||
async def test_selector_serializer(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test serialization of Selectors in Open API format."""
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
selector_serializer = api.custom_serializer
|
||||
|
||||
assert selector_serializer(selector.ActionSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.AddonSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.AreaSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.AreaSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"}
|
||||
assert selector_serializer(
|
||||
selector.AttributeSelector({"entity_id": "sensor.test"})
|
||||
) == {"type": "string"}
|
||||
assert selector_serializer(selector.BackupLocationSelector()) == {
|
||||
"type": "string",
|
||||
"pattern": "^(?:\\/backup|\\w+)$",
|
||||
}
|
||||
assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"}
|
||||
assert selector_serializer(selector.ColorRGBSelector()) == {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"format": "RGB",
|
||||
}
|
||||
assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"}
|
||||
assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1000,
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000})
|
||||
) == {"type": "number", "minimum": 100, "maximum": 1000}
|
||||
assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.ConstantSelector({"value": "test"})) == {
|
||||
"enum": ["test"]
|
||||
}
|
||||
assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]}
|
||||
assert selector_serializer(selector.ConstantSelector({"value": True})) == {
|
||||
"enum": [True]
|
||||
}
|
||||
assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == {
|
||||
"type": "string"
|
||||
}
|
||||
assert selector_serializer(selector.ConversationAgentSelector()) == {
|
||||
"type": "string"
|
||||
}
|
||||
assert selector_serializer(selector.CountrySelector()) == {
|
||||
"type": "string",
|
||||
"format": "ISO 3166-1 alpha-2",
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.CountrySelector({"countries": ["GB", "FR"]})
|
||||
) == {"type": "string", "enum": ["GB", "FR"]}
|
||||
assert selector_serializer(selector.DateSelector()) == {
|
||||
"type": "string",
|
||||
"format": "date",
|
||||
}
|
||||
assert selector_serializer(selector.DateTimeSelector()) == {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
}
|
||||
assert selector_serializer(selector.DeviceSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.DeviceSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.EntitySelector()) == {
|
||||
"type": "string",
|
||||
"format": "entity_id",
|
||||
}
|
||||
assert selector_serializer(selector.EntitySelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": "entity_id"},
|
||||
}
|
||||
assert selector_serializer(selector.FloorSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.FloorSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.IconSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.LabelSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.LabelSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.LanguageSelector()) == {
|
||||
"type": "string",
|
||||
"format": "RFC 5646",
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.LanguageSelector({"languages": ["en", "fr"]})
|
||||
) == {"type": "string", "enum": ["en", "fr"]}
|
||||
assert selector_serializer(selector.LocationSelector()) == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
"radius": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
}
|
||||
assert selector_serializer(selector.MediaSelector()) == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {"type": "string"},
|
||||
"media_content_id": {"type": "string"},
|
||||
"media_content_type": {"type": "string"},
|
||||
"metadata": {"type": "object", "additionalProperties": True},
|
||||
},
|
||||
"required": ["entity_id", "media_content_id", "media_content_type"],
|
||||
}
|
||||
assert selector_serializer(selector.NumberSelector({"mode": "box"})) == {
|
||||
"type": "number"
|
||||
}
|
||||
assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == {
|
||||
"type": "number",
|
||||
"minimum": 30,
|
||||
"maximum": 100,
|
||||
}
|
||||
assert selector_serializer(selector.ObjectSelector()) == {"type": "object"}
|
||||
assert selector_serializer(
|
||||
selector.SelectSelector(
|
||||
{
|
||||
"options": [
|
||||
{"value": "A", "label": "Letter A"},
|
||||
{"value": "B", "label": "Letter B"},
|
||||
{"value": "C", "label": "Letter C"},
|
||||
]
|
||||
}
|
||||
)
|
||||
) == {"type": "string", "enum": ["A", "B", "C"]}
|
||||
assert selector_serializer(
|
||||
selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True})
|
||||
) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": ["A", "B", "C"]},
|
||||
"uniqueItems": True,
|
||||
}
|
||||
assert selector_serializer(
|
||||
selector.StateSelector({"entity_id": "sensor.test"})
|
||||
) == {"type": "string"}
|
||||
assert selector_serializer(selector.TemplateSelector()) == {
|
||||
"type": "string",
|
||||
"format": "jinja2",
|
||||
}
|
||||
assert selector_serializer(selector.TextSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.TextSelector({"multiple": True})) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.ThemeSelector()) == {"type": "string"}
|
||||
assert selector_serializer(selector.TimeSelector()) == {
|
||||
"type": "string",
|
||||
"format": "time",
|
||||
}
|
||||
assert selector_serializer(selector.TriggerSelector()) == {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
|
||||
"type": "string"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user