1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Improve LLM tool quality by more clearly specifying device_class slots (#122723)

* Limit intent / llm API device_class slots to only necessary services and limited set of values

* Fix ruff errors

* Run ruff format

* Fix typing and improve output schema

* Fix schema and improve flattening

* Revert conftest

* Revert recorder

* Fix ruff format errors

* Update using latest version of voluptuous
This commit is contained in:
Allen Porter
2024-07-31 05:36:02 -07:00
committed by GitHub
parent 7c7b408df1
commit f14471112d
7 changed files with 183 additions and 21 deletions

View File

@@ -7,7 +7,7 @@ import asyncio
from collections.abc import Callable, Collection, Coroutine, Iterable
import dataclasses
from dataclasses import dataclass, field
from enum import Enum, auto
from enum import Enum, StrEnum, auto
from functools import cached_property
from itertools import groupby
import logging
@@ -820,6 +820,7 @@ class DynamicServiceIntentHandler(IntentHandler):
required_states: set[str] | None = None,
description: str | None = None,
platforms: set[str] | None = None,
device_classes: set[type[StrEnum]] | None = None,
) -> None:
"""Create Service Intent Handler."""
self.intent_type = intent_type
@@ -829,6 +830,7 @@ class DynamicServiceIntentHandler(IntentHandler):
self.required_states = required_states
self.description = description
self.platforms = platforms
self.device_classes = device_classes
self.required_slots: _IntentSlotsType = {}
if required_slots:
@@ -851,13 +853,38 @@ class DynamicServiceIntentHandler(IntentHandler):
@cached_property
def slot_schema(self) -> dict:
"""Return a slot schema."""
domain_validator = (
vol.In(list(self.required_domains)) if self.required_domains else cv.string
)
slot_schema = {
vol.Any("name", "area", "floor"): non_empty_string,
vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("preferred_area_id"): cv.string,
vol.Optional("preferred_floor_id"): cv.string,
vol.Optional("domain"): vol.All(cv.ensure_list, [domain_validator]),
}
if self.device_classes:
# The typical way to match enums is with vol.Coerce, but we build a
# flat list to make the API simpler to describe programmatically
flattened_device_classes = vol.In(
[
device_class.value
for device_class_enum in self.device_classes
for device_class in device_class_enum
]
)
slot_schema.update(
{
vol.Optional("device_class"): vol.All(
cv.ensure_list,
[flattened_device_classes],
)
}
)
slot_schema.update(
{
vol.Optional("preferred_area_id"): cv.string,
vol.Optional("preferred_floor_id"): cv.string,
}
)
if self.required_slots:
slot_schema.update(
@@ -910,9 +937,6 @@ class DynamicServiceIntentHandler(IntentHandler):
if "domain" in slots:
domains = set(slots["domain"]["value"])
if self.required_domains:
# Must be a subset of intent's required domain(s)
domains.intersection_update(self.required_domains)
if "device_class" in slots:
device_classes = set(slots["device_class"]["value"])
@@ -1120,6 +1144,7 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
required_states: set[str] | None = None,
description: str | None = None,
platforms: set[str] | None = None,
device_classes: set[type[StrEnum]] | None = None,
) -> None:
"""Create service handler."""
super().__init__(
@@ -1132,6 +1157,7 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
required_states=required_states,
description=description,
platforms=platforms,
device_classes=device_classes,
)
self.domain = domain
self.service = service