mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Allow template in query in sql (#150287)
This commit is contained in:
@@ -49,7 +49,9 @@ QUERY_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_COLUMN_NAME): cv.string,
|
||||
vol.Required(CONF_NAME): cv.template,
|
||||
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||||
vol.Required(CONF_QUERY): vol.All(
|
||||
cv.template, ValueTemplate.from_template, validate_sql_select
|
||||
),
|
||||
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
|
||||
vol.Optional(CONF_VALUE_TEMPLATE): vol.All(
|
||||
cv.template, ValueTemplate.from_template
|
||||
|
||||
@@ -9,8 +9,6 @@ import sqlalchemy
|
||||
from sqlalchemy.engine import Engine, Result
|
||||
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
import sqlparse
|
||||
from sqlparse.exceptions import SQLParseError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.recorder import CONF_DB_URL, get_instance
|
||||
@@ -31,21 +29,28 @@ from homeassistant.const import (
|
||||
CONF_UNIT_OF_MEASUREMENT,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import async_get_hass, callback
|
||||
from homeassistant.data_entry_flow import section
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .util import resolve_db_url
|
||||
from .util import (
|
||||
EmptyQueryError,
|
||||
InvalidSqlQuery,
|
||||
MultipleQueryError,
|
||||
NotSelectQueryError,
|
||||
UnknownQueryTypeError,
|
||||
check_and_render_sql_query,
|
||||
resolve_db_url,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OPTIONS_SCHEMA: vol.Schema = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_QUERY): selector.TextSelector(
|
||||
selector.TextSelectorConfig(multiline=True)
|
||||
),
|
||||
vol.Required(CONF_QUERY): selector.TemplateSelector(),
|
||||
vol.Required(CONF_COLUMN_NAME): selector.TextSelector(),
|
||||
vol.Required(CONF_ADVANCED_OPTIONS): section(
|
||||
vol.Schema(
|
||||
@@ -89,14 +94,12 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema(
|
||||
|
||||
def validate_sql_select(value: str) -> str:
|
||||
"""Validate that value is a SQL SELECT query."""
|
||||
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
|
||||
raise MultipleResultsFound
|
||||
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
|
||||
raise ValueError
|
||||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
|
||||
raise SQLParseError
|
||||
return str(query[0])
|
||||
hass = async_get_hass()
|
||||
try:
|
||||
return check_and_render_sql_query(hass, value)
|
||||
except (TemplateError, InvalidSqlQuery) as err:
|
||||
_LOGGER.debug("Invalid query '%s' results in '%s'", value, err.args[0])
|
||||
raise
|
||||
|
||||
|
||||
def validate_db_connection(db_url: str) -> bool:
|
||||
@@ -138,7 +141,7 @@ def validate_query(db_url: str, query: str, column: str) -> bool:
|
||||
if sess:
|
||||
sess.close()
|
||||
engine.dispose()
|
||||
raise ValueError(error) from error
|
||||
raise InvalidSqlQuery from error
|
||||
|
||||
for res in result.mappings():
|
||||
if column not in res:
|
||||
@@ -224,13 +227,13 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
except NoSuchColumnError:
|
||||
errors["column"] = "column_invalid"
|
||||
description_placeholders = {"column": column}
|
||||
except MultipleResultsFound:
|
||||
except (MultipleResultsFound, MultipleQueryError):
|
||||
errors["query"] = "multiple_queries"
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
except SQLParseError:
|
||||
except (NotSelectQueryError, UnknownQueryTypeError):
|
||||
errors["query"] = "query_no_read_only"
|
||||
except ValueError as err:
|
||||
except (TemplateError, EmptyQueryError, InvalidSqlQuery) as err:
|
||||
_LOGGER.debug("Invalid query: %s", err)
|
||||
errors["query"] = "query_invalid"
|
||||
|
||||
@@ -282,13 +285,13 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
|
||||
except NoSuchColumnError:
|
||||
errors["column"] = "column_invalid"
|
||||
description_placeholders = {"column": column}
|
||||
except MultipleResultsFound:
|
||||
except (MultipleResultsFound, MultipleQueryError):
|
||||
errors["query"] = "multiple_queries"
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
except SQLParseError:
|
||||
except (NotSelectQueryError, UnknownQueryTypeError):
|
||||
errors["query"] = "query_no_read_only"
|
||||
except ValueError as err:
|
||||
except (TemplateError, EmptyQueryError, InvalidSqlQuery) as err:
|
||||
_LOGGER.debug("Invalid query: %s", err)
|
||||
errors["query"] = "query_invalid"
|
||||
else:
|
||||
|
||||
@@ -22,7 +22,7 @@ from homeassistant.const import (
|
||||
MATCH_ALL,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.exceptions import PlatformNotReady, TemplateError
|
||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import (
|
||||
AddConfigEntryEntitiesCallback,
|
||||
@@ -40,7 +40,9 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .util import (
|
||||
InvalidSqlQuery,
|
||||
async_create_sessionmaker,
|
||||
check_and_render_sql_query,
|
||||
convert_value,
|
||||
generate_lambda_stmt,
|
||||
redact_credentials,
|
||||
@@ -81,7 +83,7 @@ async def async_setup_platform(
|
||||
return
|
||||
|
||||
name: Template = conf[CONF_NAME]
|
||||
query_str: str = conf[CONF_QUERY]
|
||||
query_template: ValueTemplate = conf[CONF_QUERY]
|
||||
value_template: ValueTemplate | None = conf.get(CONF_VALUE_TEMPLATE)
|
||||
column_name: str = conf[CONF_COLUMN_NAME]
|
||||
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
|
||||
@@ -96,7 +98,7 @@ async def async_setup_platform(
|
||||
await async_setup_sensor(
|
||||
hass,
|
||||
trigger_entity_config,
|
||||
query_str,
|
||||
query_template,
|
||||
column_name,
|
||||
value_template,
|
||||
unique_id,
|
||||
@@ -119,6 +121,13 @@ async def async_setup_entry(
|
||||
template: str | None = entry.options[CONF_ADVANCED_OPTIONS].get(CONF_VALUE_TEMPLATE)
|
||||
column_name: str = entry.options[CONF_COLUMN_NAME]
|
||||
|
||||
query_template: ValueTemplate | None = None
|
||||
try:
|
||||
query_template = ValueTemplate(query_str, hass)
|
||||
query_template.ensure_valid()
|
||||
except TemplateError as err:
|
||||
raise PlatformNotReady("Invalid SQL query template") from err
|
||||
|
||||
value_template: ValueTemplate | None = None
|
||||
if template is not None:
|
||||
try:
|
||||
@@ -137,7 +146,7 @@ async def async_setup_entry(
|
||||
await async_setup_sensor(
|
||||
hass,
|
||||
trigger_entity_config,
|
||||
query_str,
|
||||
query_template,
|
||||
column_name,
|
||||
value_template,
|
||||
entry.entry_id,
|
||||
@@ -150,7 +159,7 @@ async def async_setup_entry(
|
||||
async def async_setup_sensor(
|
||||
hass: HomeAssistant,
|
||||
trigger_entity_config: ConfigType,
|
||||
query_str: str,
|
||||
query_template: ValueTemplate,
|
||||
column_name: str,
|
||||
value_template: ValueTemplate | None,
|
||||
unique_id: str | None,
|
||||
@@ -166,22 +175,25 @@ async def async_setup_sensor(
|
||||
) = await async_create_sessionmaker(hass, db_url)
|
||||
if sessmaker is None:
|
||||
return
|
||||
validate_query(hass, query_str, uses_recorder_db, unique_id)
|
||||
validate_query(hass, query_template, uses_recorder_db, unique_id)
|
||||
|
||||
query_str = check_and_render_sql_query(hass, query_template)
|
||||
upper_query = query_str.upper()
|
||||
# MSSQL uses TOP and not LIMIT
|
||||
mod_query_template = query_template
|
||||
if not ("LIMIT" in upper_query or "SELECT TOP" in upper_query):
|
||||
if "mssql" in db_url:
|
||||
query_str = upper_query.replace("SELECT", "SELECT TOP 1")
|
||||
_query = query_template.template.replace("SELECT", "SELECT TOP 1")
|
||||
else:
|
||||
query_str = query_str.replace(";", "") + " LIMIT 1;"
|
||||
_query = query_template.template.replace(";", "") + " LIMIT 1;"
|
||||
mod_query_template = ValueTemplate(_query, hass)
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
SQLSensor(
|
||||
trigger_entity_config,
|
||||
sessmaker,
|
||||
query_str,
|
||||
mod_query_template,
|
||||
column_name,
|
||||
value_template,
|
||||
yaml,
|
||||
@@ -200,7 +212,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
self,
|
||||
trigger_entity_config: ConfigType,
|
||||
sessmaker: scoped_session,
|
||||
query: str,
|
||||
query: ValueTemplate,
|
||||
column: str,
|
||||
value_template: ValueTemplate | None,
|
||||
yaml: bool,
|
||||
@@ -214,7 +226,6 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
self.sessionmaker = sessmaker
|
||||
self._attr_extra_state_attributes = {}
|
||||
self._use_database_executor = use_database_executor
|
||||
self._lambda_stmt = generate_lambda_stmt(query)
|
||||
if not yaml and (unique_id := trigger_entity_config.get(CONF_UNIQUE_ID)):
|
||||
self._attr_name = None
|
||||
self._attr_has_entity_name = True
|
||||
@@ -255,11 +266,22 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
self._attr_extra_state_attributes = {}
|
||||
sess: scoped_session = self.sessionmaker()
|
||||
try:
|
||||
result: Result = sess.execute(self._lambda_stmt)
|
||||
rendered_query = check_and_render_sql_query(self.hass, self._query)
|
||||
_lambda_stmt = generate_lambda_stmt(rendered_query)
|
||||
result: Result = sess.execute(_lambda_stmt)
|
||||
except (TemplateError, InvalidSqlQuery) as err:
|
||||
_LOGGER.error(
|
||||
"Error rendering query %s: %s",
|
||||
redact_credentials(self._query.template),
|
||||
redact_credentials(str(err)),
|
||||
)
|
||||
sess.rollback()
|
||||
sess.close()
|
||||
return
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.error(
|
||||
"Error executing query %s: %s",
|
||||
self._query,
|
||||
rendered_query,
|
||||
redact_credentials(str(err)),
|
||||
)
|
||||
sess.rollback()
|
||||
@@ -267,7 +289,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
return
|
||||
|
||||
for res in result.mappings():
|
||||
_LOGGER.debug("Query %s result in %s", self._query, res.items())
|
||||
_LOGGER.debug("Query %s result in %s", rendered_query, res.items())
|
||||
data = res[self._column_name]
|
||||
for key, value in res.items():
|
||||
self._attr_extra_state_attributes[key] = convert_value(value)
|
||||
@@ -287,6 +309,6 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
self._attr_native_value = data
|
||||
|
||||
if data is None:
|
||||
_LOGGER.warning("%s returned no results", self._query)
|
||||
_LOGGER.warning("%s returned no results", rendered_query)
|
||||
|
||||
sess.close()
|
||||
|
||||
@@ -19,11 +19,13 @@ from homeassistant.core import (
|
||||
)
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.trigger_template_entity import ValueTemplate
|
||||
from homeassistant.util.json import JsonValueType
|
||||
|
||||
from .const import CONF_QUERY, DOMAIN
|
||||
from .util import (
|
||||
async_create_sessionmaker,
|
||||
check_and_render_sql_query,
|
||||
convert_value,
|
||||
generate_lambda_stmt,
|
||||
redact_credentials,
|
||||
@@ -37,7 +39,9 @@ _LOGGER = logging.getLogger(__name__)
|
||||
SERVICE_QUERY = "query"
|
||||
SERVICE_QUERY_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||||
vol.Required(CONF_QUERY): vol.All(
|
||||
cv.template, ValueTemplate.from_template, validate_sql_select
|
||||
),
|
||||
vol.Optional(CONF_DB_URL): cv.string,
|
||||
}
|
||||
)
|
||||
@@ -72,8 +76,9 @@ async def _async_query_service(
|
||||
def _execute_and_convert_query() -> list[JsonValueType]:
|
||||
"""Execute the query and return the results with converted types."""
|
||||
sess: Session = sessmaker()
|
||||
rendered_query = check_and_render_sql_query(call.hass, query_str)
|
||||
try:
|
||||
result: Result = sess.execute(generate_lambda_stmt(query_str))
|
||||
result: Result = sess.execute(generate_lambda_stmt(rendered_query))
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.debug(
|
||||
"Error executing query %s: %s",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"db_url_invalid": "Database URL invalid",
|
||||
"multiple_queries": "Multiple SQL queries are not supported",
|
||||
"query_invalid": "SQL query invalid",
|
||||
"query_no_read_only": "SQL query must be read-only"
|
||||
"query_no_read_only": "SQL query is not a read-only SELECT query or it's of an unknown type"
|
||||
},
|
||||
"step": {
|
||||
"options": {
|
||||
|
||||
@@ -19,7 +19,9 @@ import voluptuous as vol
|
||||
from homeassistant.components.recorder import SupportedDialect, get_instance
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.helpers.template import Template
|
||||
|
||||
from .const import DB_URL_RE, DOMAIN
|
||||
from .models import SQLData
|
||||
@@ -44,16 +46,14 @@ def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str:
|
||||
return get_instance(hass).db_url
|
||||
|
||||
|
||||
def validate_sql_select(value: str) -> str:
|
||||
def validate_sql_select(value: Template) -> Template:
|
||||
"""Validate that value is a SQL SELECT query."""
|
||||
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
|
||||
raise vol.Invalid("Multiple SQL queries are not supported")
|
||||
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
|
||||
raise vol.Invalid("Invalid SQL query")
|
||||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
|
||||
raise vol.Invalid("Only SELECT queries allowed")
|
||||
return str(query[0])
|
||||
try:
|
||||
assert value.hass
|
||||
check_and_render_sql_query(value.hass, value)
|
||||
except (TemplateError, InvalidSqlQuery) as err:
|
||||
raise vol.Invalid(str(err)) from err
|
||||
return value
|
||||
|
||||
|
||||
async def async_create_sessionmaker(
|
||||
@@ -113,7 +113,7 @@ async def async_create_sessionmaker(
|
||||
|
||||
def validate_query(
|
||||
hass: HomeAssistant,
|
||||
query_str: str,
|
||||
query_template: str | Template,
|
||||
uses_recorder_db: bool,
|
||||
unique_id: str | None = None,
|
||||
) -> None:
|
||||
@@ -121,7 +121,7 @@ def validate_query(
|
||||
|
||||
Args:
|
||||
hass: The Home Assistant instance.
|
||||
query_str: The SQL query string to be validated.
|
||||
query_template: The SQL query string to be validated.
|
||||
uses_recorder_db: A boolean indicating if the query is against the recorder database.
|
||||
unique_id: The unique ID of the entity, used for creating issue registry keys.
|
||||
|
||||
@@ -131,6 +131,10 @@ def validate_query(
|
||||
"""
|
||||
if not uses_recorder_db:
|
||||
return
|
||||
if isinstance(query_template, Template):
|
||||
query_str = query_template.async_render()
|
||||
else:
|
||||
query_str = Template(query_template, hass).async_render()
|
||||
redacted_query = redact_credentials(query_str)
|
||||
|
||||
issue_key = unique_id if unique_id else redacted_query
|
||||
@@ -239,3 +243,49 @@ def convert_value(value: Any) -> Any:
|
||||
return f"0x{value.hex()}"
|
||||
case _:
|
||||
return value
|
||||
|
||||
|
||||
def check_and_render_sql_query(hass: HomeAssistant, query: Template | str) -> str:
|
||||
"""Check and render SQL query."""
|
||||
if isinstance(query, str):
|
||||
query = query.strip()
|
||||
if not query:
|
||||
raise EmptyQueryError("Query cannot be empty")
|
||||
query = Template(query, hass=hass)
|
||||
|
||||
# Raises TemplateError if template is invalid
|
||||
query.ensure_valid()
|
||||
rendered_query: str = query.async_render()
|
||||
|
||||
if len(rendered_queries := sqlparse.parse(rendered_query.lstrip().lstrip(";"))) > 1:
|
||||
raise MultipleQueryError("Multiple SQL statements are not allowed")
|
||||
if (
|
||||
len(rendered_queries) == 0
|
||||
or (query_type := rendered_queries[0].get_type()) == "UNKNOWN"
|
||||
):
|
||||
raise UnknownQueryTypeError("SQL query is empty or unknown type")
|
||||
if query_type != "SELECT":
|
||||
_LOGGER.debug("The SQL query %s is of type %s", rendered_query, query_type)
|
||||
raise NotSelectQueryError("SQL query must be of type SELECT")
|
||||
|
||||
return str(rendered_queries[0])
|
||||
|
||||
|
||||
class InvalidSqlQuery(HomeAssistantError):
|
||||
"""SQL query is invalid error."""
|
||||
|
||||
|
||||
class EmptyQueryError(InvalidSqlQuery):
|
||||
"""SQL query is empty error."""
|
||||
|
||||
|
||||
class MultipleQueryError(InvalidSqlQuery):
|
||||
"""SQL query is multiple error."""
|
||||
|
||||
|
||||
class UnknownQueryTypeError(InvalidSqlQuery):
|
||||
"""SQL query is of unknown type error."""
|
||||
|
||||
|
||||
class NotSelectQueryError(InvalidSqlQuery):
|
||||
"""SQL query is not a SELECT statement error."""
|
||||
|
||||
@@ -44,6 +44,17 @@ ENTRY_CONFIG = {
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_BLANK_QUERY = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: " ",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
|
||||
CONF_STATE_CLASS: SensorStateClass.TOTAL,
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
@@ -53,6 +64,33 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_WITH_QUERY_TEMPLATE = {
|
||||
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE = {
|
||||
CONF_QUERY: "SELECT {{ 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT = {
|
||||
CONF_QUERY: "SELECT {{ 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY = {
|
||||
CONF_QUERY: "SELECT 5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -10,7 +11,7 @@ import pytest
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.recorder import CONF_DB_URL
|
||||
from homeassistant.components.recorder import CONF_DB_URL, Recorder
|
||||
from homeassistant.components.sensor import (
|
||||
CONF_STATE_CLASS,
|
||||
SensorDeviceClass,
|
||||
@@ -29,7 +30,7 @@ from homeassistant.const import (
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.data_entry_flow import FlowResultType, InvalidData
|
||||
|
||||
from . import (
|
||||
ENTRY_CONFIG,
|
||||
@@ -48,6 +49,9 @@ from . import (
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
|
||||
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE,
|
||||
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT,
|
||||
ENTRY_CONFIG_WITH_QUERY_TEMPLATE,
|
||||
ENTRY_CONFIG_WITH_VALUE_TEMPLATE,
|
||||
)
|
||||
|
||||
@@ -106,7 +110,91 @@ async def test_form_simple(
|
||||
}
|
||||
|
||||
|
||||
async def test_form_with_value_template(hass: HomeAssistant) -> None:
|
||||
async def test_form_with_query_template(
|
||||
recorder_mock: Recorder, hass: HomeAssistant
|
||||
) -> None:
|
||||
"""Test for with query template."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {}
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.sql.async_setup_entry",
|
||||
return_value=True,
|
||||
) as mock_setup_entry:
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
DATA_CONFIG,
|
||||
)
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
ENTRY_CONFIG_WITH_QUERY_TEMPLATE,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Get Value"
|
||||
assert result["options"] == {
|
||||
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
},
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_form_with_broken_query_template(
|
||||
recorder_mock: Recorder, hass: HomeAssistant
|
||||
) -> None:
|
||||
"""Test form with broken query template."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {}
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
DATA_CONFIG,
|
||||
)
|
||||
message = re.escape("Schema validation failed @ data['query']")
|
||||
with pytest.raises(InvalidData, match=message):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.sql.async_setup_entry",
|
||||
return_value=True,
|
||||
) as mock_setup_entry:
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
ENTRY_CONFIG_WITH_QUERY_TEMPLATE,
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Get Value"
|
||||
assert result["options"] == {
|
||||
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
},
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_form_with_value_template(
|
||||
recorder_mock: Recorder, hass: HomeAssistant
|
||||
) -> None:
|
||||
"""Test for with value template."""
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
@@ -192,7 +280,7 @@ async def test_flow_fails_invalid_query(hass: HomeAssistant) -> None:
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {
|
||||
CONF_QUERY: "query_invalid",
|
||||
CONF_QUERY: "query_no_read_only",
|
||||
}
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
@@ -202,7 +290,7 @@ async def test_flow_fails_invalid_query(hass: HomeAssistant) -> None:
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {
|
||||
CONF_QUERY: "query_invalid",
|
||||
CONF_QUERY: "query_no_read_only",
|
||||
}
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
@@ -484,7 +572,7 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None:
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {
|
||||
CONF_QUERY: "query_invalid",
|
||||
CONF_QUERY: "query_no_read_only",
|
||||
}
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
@@ -494,9 +582,8 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None:
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {
|
||||
CONF_QUERY: "query_invalid",
|
||||
CONF_QUERY: "query_no_read_only",
|
||||
}
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
|
||||
@@ -527,6 +614,13 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None:
|
||||
CONF_QUERY: "multiple_queries",
|
||||
}
|
||||
|
||||
message = re.escape("Schema validation failed @ data['query']")
|
||||
with pytest.raises(InvalidData, match=message):
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT,
|
||||
)
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
|
||||
@@ -4,6 +4,9 @@ from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.recorder import CONF_DB_URL, Recorder
|
||||
from homeassistant.components.sensor import (
|
||||
CONF_STATE_CLASS,
|
||||
@@ -16,6 +19,7 @@ from homeassistant.components.sql.const import (
|
||||
CONF_QUERY,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.components.sql.util import validate_sql_select
|
||||
from homeassistant.config_entries import SOURCE_USER, ConfigEntryState
|
||||
from homeassistant.const import (
|
||||
CONF_DEVICE_CLASS,
|
||||
@@ -24,6 +28,7 @@ from homeassistant.const import (
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import YAML_CONFIG_INVALID, YAML_CONFIG_NO_DB, init_integration
|
||||
@@ -67,6 +72,45 @@ async def test_setup_invalid_config(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_invalid_query(hass: HomeAssistant) -> None:
|
||||
"""Test invalid query."""
|
||||
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
|
||||
validate_sql_select(Template("DROP TABLE *", hass))
|
||||
|
||||
with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"):
|
||||
validate_sql_select(Template("SELECT5 as value", hass))
|
||||
|
||||
with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"):
|
||||
validate_sql_select(Template(";;", hass))
|
||||
|
||||
|
||||
async def test_query_no_read_only(hass: HomeAssistant) -> None:
|
||||
"""Test query no read only."""
|
||||
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
|
||||
validate_sql_select(
|
||||
Template("UPDATE states SET state = 999999 WHERE state_id = 11125", hass)
|
||||
)
|
||||
|
||||
|
||||
async def test_query_no_read_only_cte(hass: HomeAssistant) -> None:
|
||||
"""Test query no read only CTE."""
|
||||
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
|
||||
validate_sql_select(
|
||||
Template(
|
||||
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
|
||||
hass,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def test_multiple_queries(hass: HomeAssistant) -> None:
|
||||
"""Test multiple queries."""
|
||||
with pytest.raises(vol.Invalid, match="Multiple SQL statements are not allowed"):
|
||||
validate_sql_select(
|
||||
Template("SELECT 5 as value; UPDATE states SET state = 10;", hass)
|
||||
)
|
||||
|
||||
|
||||
async def test_migration_from_future(
|
||||
recorder_mock: Recorder, hass: HomeAssistant
|
||||
) -> None:
|
||||
|
||||
@@ -39,7 +39,6 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.helpers.entity_platform import async_get_platforms
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
@@ -109,6 +108,33 @@ async def test_query_value_template(
|
||||
}
|
||||
|
||||
|
||||
async def test_template_query(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test the SQL sensor with a query template."""
|
||||
options = {
|
||||
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_VALUE_TEMPLATE: "{{ value | int }}",
|
||||
},
|
||||
}
|
||||
await init_integration(hass, title="count_tables", options=options)
|
||||
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert state.state == "6"
|
||||
|
||||
hass.states.async_set("sensor.input1", "on")
|
||||
freezer.tick(timedelta(minutes=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert state.state == "5"
|
||||
|
||||
|
||||
async def test_query_value_template_invalid(
|
||||
recorder_mock: Recorder, hass: HomeAssistant
|
||||
) -> None:
|
||||
@@ -124,6 +150,59 @@ async def test_query_value_template_invalid(
|
||||
assert state.state == "5.01"
|
||||
|
||||
|
||||
async def test_broken_template_query(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test the SQL sensor with a query template which is broken."""
|
||||
options = {
|
||||
CONF_QUERY: "SELECT {{ 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_VALUE_TEMPLATE: "{{ value | int }}",
|
||||
},
|
||||
}
|
||||
await init_integration(hass, title="count_tables", options=options)
|
||||
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert not state
|
||||
|
||||
|
||||
async def test_broken_template_query_2(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test the SQL sensor with a query template."""
|
||||
hass.states.async_set("sensor.input1", "5")
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
options = {
|
||||
CONF_QUERY: "SELECT {{ states.sensor.input1.state | int / 1000}} as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
}
|
||||
await init_integration(hass, title="count_tables", options=options)
|
||||
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert state.state == "0.005"
|
||||
|
||||
hass.states.async_set("sensor.input1", "on")
|
||||
freezer.tick(timedelta(minutes=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
state = hass.states.get("sensor.count_tables")
|
||||
assert state.state == "0.005"
|
||||
assert (
|
||||
"Error rendering query SELECT {{ states.sensor.input1.state | int / 1000}} as value"
|
||||
" LIMIT 1;: ValueError: Template error: int got invalid input 'on' when rendering"
|
||||
" template 'SELECT {{ states.sensor.input1.state | int / 1000}} as value LIMIT 1;'"
|
||||
" but no default was specified" in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_query_limit(recorder_mock: Recorder, hass: HomeAssistant) -> None:
|
||||
"""Test the SQL sensor with a query containing 'LIMIT' in lowercase."""
|
||||
options = {
|
||||
@@ -641,17 +720,14 @@ async def test_query_recover_from_rollback(
|
||||
CONF_UNIQUE_ID: "very_unique_id",
|
||||
}
|
||||
await init_integration(hass, title="Select value SQL query", options=options)
|
||||
platforms = async_get_platforms(hass, "sql")
|
||||
sql_entity = platforms[0].entities["sensor.select_value_sql_query"]
|
||||
|
||||
state = hass.states.get("sensor.select_value_sql_query")
|
||||
assert state.state == "5"
|
||||
assert state.attributes["value"] == 5
|
||||
|
||||
with patch.object(
|
||||
sql_entity,
|
||||
"_lambda_stmt",
|
||||
generate_lambda_stmt("Faulty syntax create operational issue"),
|
||||
with patch(
|
||||
"homeassistant.components.sql.sensor.generate_lambda_stmt",
|
||||
return_value=generate_lambda_stmt("Faulty syntax create operational issue"),
|
||||
):
|
||||
freezer.tick(timedelta(minutes=1))
|
||||
async_fire_time_changed(hass)
|
||||
|
||||
@@ -153,7 +153,7 @@ async def test_query_service_invalid_query_not_select(
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(vol.Invalid, match="Only SELECT queries allowed"):
|
||||
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
@@ -171,7 +171,7 @@ async def test_query_service_sqlalchemy_error(
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(MultipleInvalid, match="Invalid SQL query"):
|
||||
with pytest.raises(MultipleInvalid, match="SQL query is empty or unknown type"):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
|
||||
@@ -13,6 +13,7 @@ from homeassistant.components.sql.util import (
|
||||
validate_sql_select,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.template import Template
|
||||
|
||||
|
||||
async def test_resolve_db_url_when_none_configured(
|
||||
@@ -39,27 +40,27 @@ async def test_resolve_db_url_when_configured(hass: HomeAssistant) -> None:
|
||||
[
|
||||
(
|
||||
"DROP TABLE *",
|
||||
"Only SELECT queries allowed",
|
||||
"SQL query must be of type SELECT",
|
||||
),
|
||||
(
|
||||
"SELECT5 as value",
|
||||
"Invalid SQL query",
|
||||
"SQL query is empty or unknown type",
|
||||
),
|
||||
(
|
||||
";;",
|
||||
"Invalid SQL query",
|
||||
"SQL query is empty or unknown type",
|
||||
),
|
||||
(
|
||||
"UPDATE states SET state = 999999 WHERE state_id = 11125",
|
||||
"Only SELECT queries allowed",
|
||||
"SQL query must be of type SELECT",
|
||||
),
|
||||
(
|
||||
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
|
||||
"Only SELECT queries allowed",
|
||||
"SQL query must be of type SELECT",
|
||||
),
|
||||
(
|
||||
"SELECT 5 as value; UPDATE states SET state = 10;",
|
||||
"Multiple SQL queries are not supported",
|
||||
"Multiple SQL statements are not allowed",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -70,7 +71,7 @@ async def test_invalid_sql_queries(
|
||||
) -> None:
|
||||
"""Test that various invalid or disallowed SQL queries raise the correct exception."""
|
||||
with pytest.raises(vol.Invalid, match=expected_error_message):
|
||||
validate_sql_select(sql_query)
|
||||
validate_sql_select(Template(sql_query, hass))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user