1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-15 07:36:16 +00:00

Allow template in query in sql (#150287)

This commit is contained in:
G Johansson
2025-11-06 17:11:46 +01:00
committed by GitHub
parent 2ddf55a60d
commit 67ccdd36fb
12 changed files with 411 additions and 76 deletions

View File

@@ -49,7 +49,9 @@ QUERY_SCHEMA = vol.Schema(
{
vol.Required(CONF_COLUMN_NAME): cv.string,
vol.Required(CONF_NAME): cv.template,
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
vol.Required(CONF_QUERY): vol.All(
cv.template, ValueTemplate.from_template, validate_sql_select
),
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
vol.Optional(CONF_VALUE_TEMPLATE): vol.All(
cv.template, ValueTemplate.from_template

View File

@@ -9,8 +9,6 @@ import sqlalchemy
from sqlalchemy.engine import Engine, Result
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker
import sqlparse
from sqlparse.exceptions import SQLParseError
import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL, get_instance
@@ -31,21 +29,28 @@ from homeassistant.const import (
CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import callback
from homeassistant.core import async_get_hass, callback
from homeassistant.data_entry_flow import section
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import selector
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import resolve_db_url
from .util import (
EmptyQueryError,
InvalidSqlQuery,
MultipleQueryError,
NotSelectQueryError,
UnknownQueryTypeError,
check_and_render_sql_query,
resolve_db_url,
)
_LOGGER = logging.getLogger(__name__)
OPTIONS_SCHEMA: vol.Schema = vol.Schema(
{
vol.Required(CONF_QUERY): selector.TextSelector(
selector.TextSelectorConfig(multiline=True)
),
vol.Required(CONF_QUERY): selector.TemplateSelector(),
vol.Required(CONF_COLUMN_NAME): selector.TextSelector(),
vol.Required(CONF_ADVANCED_OPTIONS): section(
vol.Schema(
@@ -89,14 +94,12 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema(
def validate_sql_select(value: str) -> str:
"""Validate that value is a SQL SELECT query."""
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
raise MultipleResultsFound
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
raise ValueError
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
raise SQLParseError
return str(query[0])
hass = async_get_hass()
try:
return check_and_render_sql_query(hass, value)
except (TemplateError, InvalidSqlQuery) as err:
_LOGGER.debug("Invalid query '%s' results in '%s'", value, err.args[0])
raise
def validate_db_connection(db_url: str) -> bool:
@@ -138,7 +141,7 @@ def validate_query(db_url: str, query: str, column: str) -> bool:
if sess:
sess.close()
engine.dispose()
raise ValueError(error) from error
raise InvalidSqlQuery from error
for res in result.mappings():
if column not in res:
@@ -224,13 +227,13 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
except NoSuchColumnError:
errors["column"] = "column_invalid"
description_placeholders = {"column": column}
except MultipleResultsFound:
except (MultipleResultsFound, MultipleQueryError):
errors["query"] = "multiple_queries"
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
except SQLParseError:
except (NotSelectQueryError, UnknownQueryTypeError):
errors["query"] = "query_no_read_only"
except ValueError as err:
except (TemplateError, EmptyQueryError, InvalidSqlQuery) as err:
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid"
@@ -282,13 +285,13 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
except NoSuchColumnError:
errors["column"] = "column_invalid"
description_placeholders = {"column": column}
except MultipleResultsFound:
except (MultipleResultsFound, MultipleQueryError):
errors["query"] = "multiple_queries"
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
except SQLParseError:
except (NotSelectQueryError, UnknownQueryTypeError):
errors["query"] = "query_no_read_only"
except ValueError as err:
except (TemplateError, EmptyQueryError, InvalidSqlQuery) as err:
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid"
else:

View File

@@ -22,7 +22,7 @@ from homeassistant.const import (
MATCH_ALL,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.exceptions import PlatformNotReady, TemplateError
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import (
AddConfigEntryEntitiesCallback,
@@ -40,7 +40,9 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import (
InvalidSqlQuery,
async_create_sessionmaker,
check_and_render_sql_query,
convert_value,
generate_lambda_stmt,
redact_credentials,
@@ -81,7 +83,7 @@ async def async_setup_platform(
return
name: Template = conf[CONF_NAME]
query_str: str = conf[CONF_QUERY]
query_template: ValueTemplate = conf[CONF_QUERY]
value_template: ValueTemplate | None = conf.get(CONF_VALUE_TEMPLATE)
column_name: str = conf[CONF_COLUMN_NAME]
unique_id: str | None = conf.get(CONF_UNIQUE_ID)
@@ -96,7 +98,7 @@ async def async_setup_platform(
await async_setup_sensor(
hass,
trigger_entity_config,
query_str,
query_template,
column_name,
value_template,
unique_id,
@@ -119,6 +121,13 @@ async def async_setup_entry(
template: str | None = entry.options[CONF_ADVANCED_OPTIONS].get(CONF_VALUE_TEMPLATE)
column_name: str = entry.options[CONF_COLUMN_NAME]
query_template: ValueTemplate | None = None
try:
query_template = ValueTemplate(query_str, hass)
query_template.ensure_valid()
except TemplateError as err:
raise PlatformNotReady("Invalid SQL query template") from err
value_template: ValueTemplate | None = None
if template is not None:
try:
@@ -137,7 +146,7 @@ async def async_setup_entry(
await async_setup_sensor(
hass,
trigger_entity_config,
query_str,
query_template,
column_name,
value_template,
entry.entry_id,
@@ -150,7 +159,7 @@ async def async_setup_entry(
async def async_setup_sensor(
hass: HomeAssistant,
trigger_entity_config: ConfigType,
query_str: str,
query_template: ValueTemplate,
column_name: str,
value_template: ValueTemplate | None,
unique_id: str | None,
@@ -166,22 +175,25 @@ async def async_setup_sensor(
) = await async_create_sessionmaker(hass, db_url)
if sessmaker is None:
return
validate_query(hass, query_str, uses_recorder_db, unique_id)
validate_query(hass, query_template, uses_recorder_db, unique_id)
query_str = check_and_render_sql_query(hass, query_template)
upper_query = query_str.upper()
# MSSQL uses TOP and not LIMIT
mod_query_template = query_template
if not ("LIMIT" in upper_query or "SELECT TOP" in upper_query):
if "mssql" in db_url:
query_str = upper_query.replace("SELECT", "SELECT TOP 1")
_query = query_template.template.replace("SELECT", "SELECT TOP 1")
else:
query_str = query_str.replace(";", "") + " LIMIT 1;"
_query = query_template.template.replace(";", "") + " LIMIT 1;"
mod_query_template = ValueTemplate(_query, hass)
async_add_entities(
[
SQLSensor(
trigger_entity_config,
sessmaker,
query_str,
mod_query_template,
column_name,
value_template,
yaml,
@@ -200,7 +212,7 @@ class SQLSensor(ManualTriggerSensorEntity):
self,
trigger_entity_config: ConfigType,
sessmaker: scoped_session,
query: str,
query: ValueTemplate,
column: str,
value_template: ValueTemplate | None,
yaml: bool,
@@ -214,7 +226,6 @@ class SQLSensor(ManualTriggerSensorEntity):
self.sessionmaker = sessmaker
self._attr_extra_state_attributes = {}
self._use_database_executor = use_database_executor
self._lambda_stmt = generate_lambda_stmt(query)
if not yaml and (unique_id := trigger_entity_config.get(CONF_UNIQUE_ID)):
self._attr_name = None
self._attr_has_entity_name = True
@@ -255,11 +266,22 @@ class SQLSensor(ManualTriggerSensorEntity):
self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker()
try:
result: Result = sess.execute(self._lambda_stmt)
rendered_query = check_and_render_sql_query(self.hass, self._query)
_lambda_stmt = generate_lambda_stmt(rendered_query)
result: Result = sess.execute(_lambda_stmt)
except (TemplateError, InvalidSqlQuery) as err:
_LOGGER.error(
"Error rendering query %s: %s",
redact_credentials(self._query.template),
redact_credentials(str(err)),
)
sess.rollback()
sess.close()
return
except SQLAlchemyError as err:
_LOGGER.error(
"Error executing query %s: %s",
self._query,
rendered_query,
redact_credentials(str(err)),
)
sess.rollback()
@@ -267,7 +289,7 @@ class SQLSensor(ManualTriggerSensorEntity):
return
for res in result.mappings():
_LOGGER.debug("Query %s result in %s", self._query, res.items())
_LOGGER.debug("Query %s result in %s", rendered_query, res.items())
data = res[self._column_name]
for key, value in res.items():
self._attr_extra_state_attributes[key] = convert_value(value)
@@ -287,6 +309,6 @@ class SQLSensor(ManualTriggerSensorEntity):
self._attr_native_value = data
if data is None:
_LOGGER.warning("%s returned no results", self._query)
_LOGGER.warning("%s returned no results", rendered_query)
sess.close()

View File

@@ -19,11 +19,13 @@ from homeassistant.core import (
)
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.trigger_template_entity import ValueTemplate
from homeassistant.util.json import JsonValueType
from .const import CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
check_and_render_sql_query,
convert_value,
generate_lambda_stmt,
redact_credentials,
@@ -37,7 +39,9 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_QUERY = "query"
SERVICE_QUERY_SCHEMA = vol.Schema(
{
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
vol.Required(CONF_QUERY): vol.All(
cv.template, ValueTemplate.from_template, validate_sql_select
),
vol.Optional(CONF_DB_URL): cv.string,
}
)
@@ -72,8 +76,9 @@ async def _async_query_service(
def _execute_and_convert_query() -> list[JsonValueType]:
"""Execute the query and return the results with converted types."""
sess: Session = sessmaker()
rendered_query = check_and_render_sql_query(call.hass, query_str)
try:
result: Result = sess.execute(generate_lambda_stmt(query_str))
result: Result = sess.execute(generate_lambda_stmt(rendered_query))
except SQLAlchemyError as err:
_LOGGER.debug(
"Error executing query %s: %s",

View File

@@ -8,7 +8,7 @@
"db_url_invalid": "Database URL invalid",
"multiple_queries": "Multiple SQL queries are not supported",
"query_invalid": "SQL query invalid",
"query_no_read_only": "SQL query must be read-only"
"query_no_read_only": "SQL query is not a read-only SELECT query or it's of an unknown type"
},
"step": {
"options": {

View File

@@ -19,7 +19,9 @@ import voluptuous as vol
from homeassistant.components.recorder import SupportedDialect, get_instance
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.template import Template
from .const import DB_URL_RE, DOMAIN
from .models import SQLData
@@ -44,16 +46,14 @@ def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str:
return get_instance(hass).db_url
def validate_sql_select(value: str) -> str:
def validate_sql_select(value: Template) -> Template:
"""Validate that value is a SQL SELECT query."""
if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1:
raise vol.Invalid("Multiple SQL queries are not supported")
if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN":
raise vol.Invalid("Invalid SQL query")
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", query, query_type)
raise vol.Invalid("Only SELECT queries allowed")
return str(query[0])
try:
assert value.hass
check_and_render_sql_query(value.hass, value)
except (TemplateError, InvalidSqlQuery) as err:
raise vol.Invalid(str(err)) from err
return value
async def async_create_sessionmaker(
@@ -113,7 +113,7 @@ async def async_create_sessionmaker(
def validate_query(
hass: HomeAssistant,
query_str: str,
query_template: str | Template,
uses_recorder_db: bool,
unique_id: str | None = None,
) -> None:
@@ -121,7 +121,7 @@ def validate_query(
Args:
hass: The Home Assistant instance.
query_str: The SQL query string to be validated.
query_template: The SQL query string to be validated.
uses_recorder_db: A boolean indicating if the query is against the recorder database.
unique_id: The unique ID of the entity, used for creating issue registry keys.
@@ -131,6 +131,10 @@ def validate_query(
"""
if not uses_recorder_db:
return
if isinstance(query_template, Template):
query_str = query_template.async_render()
else:
query_str = Template(query_template, hass).async_render()
redacted_query = redact_credentials(query_str)
issue_key = unique_id if unique_id else redacted_query
@@ -239,3 +243,49 @@ def convert_value(value: Any) -> Any:
return f"0x{value.hex()}"
case _:
return value
def check_and_render_sql_query(hass: HomeAssistant, query: Template | str) -> str:
"""Check and render SQL query."""
if isinstance(query, str):
query = query.strip()
if not query:
raise EmptyQueryError("Query cannot be empty")
query = Template(query, hass=hass)
# Raises TemplateError if template is invalid
query.ensure_valid()
rendered_query: str = query.async_render()
if len(rendered_queries := sqlparse.parse(rendered_query.lstrip().lstrip(";"))) > 1:
raise MultipleQueryError("Multiple SQL statements are not allowed")
if (
len(rendered_queries) == 0
or (query_type := rendered_queries[0].get_type()) == "UNKNOWN"
):
raise UnknownQueryTypeError("SQL query is empty or unknown type")
if query_type != "SELECT":
_LOGGER.debug("The SQL query %s is of type %s", rendered_query, query_type)
raise NotSelectQueryError("SQL query must be of type SELECT")
return str(rendered_queries[0])
class InvalidSqlQuery(HomeAssistantError):
"""SQL query is invalid error."""
class EmptyQueryError(InvalidSqlQuery):
"""SQL query is empty error."""
class MultipleQueryError(InvalidSqlQuery):
"""SQL query is multiple error."""
class UnknownQueryTypeError(InvalidSqlQuery):
"""SQL query is of unknown type error."""
class NotSelectQueryError(InvalidSqlQuery):
"""SQL query is not a SELECT statement error."""

View File

@@ -44,6 +44,17 @@ ENTRY_CONFIG = {
},
}
ENTRY_CONFIG_BLANK_QUERY = {
CONF_NAME: "Get Value",
CONF_QUERY: " ",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
@@ -53,6 +64,33 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
},
}
ENTRY_CONFIG_WITH_QUERY_TEMPLATE = {
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
},
}
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE = {
CONF_QUERY: "SELECT {{ 5 as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
},
}
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT = {
CONF_QUERY: "SELECT {{ 5 as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
},
}
ENTRY_CONFIG_INVALID_QUERY = {
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from pathlib import Path
import re
from typing import Any
from unittest.mock import patch
@@ -10,7 +11,7 @@ import pytest
from sqlalchemy.exc import SQLAlchemyError
from homeassistant import config_entries
from homeassistant.components.recorder import CONF_DB_URL
from homeassistant.components.recorder import CONF_DB_URL, Recorder
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
SensorDeviceClass,
@@ -29,7 +30,7 @@ from homeassistant.const import (
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.data_entry_flow import FlowResultType, InvalidData
from . import (
ENTRY_CONFIG,
@@ -48,6 +49,9 @@ from . import (
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT,
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE,
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT,
ENTRY_CONFIG_WITH_QUERY_TEMPLATE,
ENTRY_CONFIG_WITH_VALUE_TEMPLATE,
)
@@ -106,7 +110,91 @@ async def test_form_simple(
}
async def test_form_with_value_template(hass: HomeAssistant) -> None:
async def test_form_with_query_template(
recorder_mock: Recorder, hass: HomeAssistant
) -> None:
"""Test for with query template."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {}
with patch(
"homeassistant.components.sql.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
DATA_CONFIG,
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
ENTRY_CONFIG_WITH_QUERY_TEMPLATE,
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "Get Value"
assert result["options"] == {
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
},
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_with_broken_query_template(
recorder_mock: Recorder, hass: HomeAssistant
) -> None:
"""Test form with broken query template."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {}
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
DATA_CONFIG,
)
message = re.escape("Schema validation failed @ data['query']")
with pytest.raises(InvalidData, match=message):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE,
)
with patch(
"homeassistant.components.sql.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
ENTRY_CONFIG_WITH_QUERY_TEMPLATE,
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "Get Value"
assert result["options"] == {
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
},
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_with_value_template(
recorder_mock: Recorder, hass: HomeAssistant
) -> None:
"""Test for with value template."""
result = await hass.config_entries.flow.async_init(
@@ -192,7 +280,7 @@ async def test_flow_fails_invalid_query(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {
CONF_QUERY: "query_invalid",
CONF_QUERY: "query_no_read_only",
}
result = await hass.config_entries.flow.async_configure(
@@ -202,7 +290,7 @@ async def test_flow_fails_invalid_query(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {
CONF_QUERY: "query_invalid",
CONF_QUERY: "query_no_read_only",
}
result = await hass.config_entries.flow.async_configure(
@@ -484,7 +572,7 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {
CONF_QUERY: "query_invalid",
CONF_QUERY: "query_no_read_only",
}
result = await hass.config_entries.options.async_configure(
@@ -494,9 +582,8 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {
CONF_QUERY: "query_invalid",
CONF_QUERY: "query_no_read_only",
}
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT,
@@ -527,6 +614,13 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None:
CONF_QUERY: "multiple_queries",
}
message = re.escape("Schema validation failed @ data['query']")
with pytest.raises(InvalidData, match=message):
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input=ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT,
)
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={

View File

@@ -4,6 +4,9 @@ from __future__ import annotations
from unittest.mock import patch
import pytest
import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL, Recorder
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
@@ -16,6 +19,7 @@ from homeassistant.components.sql.const import (
CONF_QUERY,
DOMAIN,
)
from homeassistant.components.sql.util import validate_sql_select
from homeassistant.config_entries import SOURCE_USER, ConfigEntryState
from homeassistant.const import (
CONF_DEVICE_CLASS,
@@ -24,6 +28,7 @@ from homeassistant.const import (
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.template import Template
from homeassistant.setup import async_setup_component
from . import YAML_CONFIG_INVALID, YAML_CONFIG_NO_DB, init_integration
@@ -67,6 +72,45 @@ async def test_setup_invalid_config(
await hass.async_block_till_done()
async def test_invalid_query(hass: HomeAssistant) -> None:
"""Test invalid query."""
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
validate_sql_select(Template("DROP TABLE *", hass))
with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"):
validate_sql_select(Template("SELECT5 as value", hass))
with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"):
validate_sql_select(Template(";;", hass))
async def test_query_no_read_only(hass: HomeAssistant) -> None:
"""Test query no read only."""
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
validate_sql_select(
Template("UPDATE states SET state = 999999 WHERE state_id = 11125", hass)
)
async def test_query_no_read_only_cte(hass: HomeAssistant) -> None:
"""Test query no read only CTE."""
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
validate_sql_select(
Template(
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
hass,
)
)
async def test_multiple_queries(hass: HomeAssistant) -> None:
"""Test multiple queries."""
with pytest.raises(vol.Invalid, match="Multiple SQL statements are not allowed"):
validate_sql_select(
Template("SELECT 5 as value; UPDATE states SET state = 10;", hass)
)
async def test_migration_from_future(
recorder_mock: Recorder, hass: HomeAssistant
) -> None:

View File

@@ -39,7 +39,6 @@ from homeassistant.const import (
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.entity_platform import async_get_platforms
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
@@ -109,6 +108,33 @@ async def test_query_value_template(
}
async def test_template_query(
recorder_mock: Recorder,
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test the SQL sensor with a query template."""
options = {
CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_VALUE_TEMPLATE: "{{ value | int }}",
},
}
await init_integration(hass, title="count_tables", options=options)
state = hass.states.get("sensor.count_tables")
assert state.state == "6"
hass.states.async_set("sensor.input1", "on")
freezer.tick(timedelta(minutes=1))
async_fire_time_changed(hass)
await hass.async_block_till_done(wait_background_tasks=True)
state = hass.states.get("sensor.count_tables")
assert state.state == "5"
async def test_query_value_template_invalid(
recorder_mock: Recorder, hass: HomeAssistant
) -> None:
@@ -124,6 +150,59 @@ async def test_query_value_template_invalid(
assert state.state == "5.01"
async def test_broken_template_query(
recorder_mock: Recorder,
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test the SQL sensor with a query template which is broken."""
options = {
CONF_QUERY: "SELECT {{ 5 as value",
CONF_COLUMN_NAME: "value",
CONF_ADVANCED_OPTIONS: {
CONF_VALUE_TEMPLATE: "{{ value | int }}",
},
}
await init_integration(hass, title="count_tables", options=options)
state = hass.states.get("sensor.count_tables")
assert not state
async def test_broken_template_query_2(
recorder_mock: Recorder,
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test the SQL sensor with a query template."""
hass.states.async_set("sensor.input1", "5")
await hass.async_block_till_done(wait_background_tasks=True)
options = {
CONF_QUERY: "SELECT {{ states.sensor.input1.state | int / 1000}} as value",
CONF_COLUMN_NAME: "value",
}
await init_integration(hass, title="count_tables", options=options)
state = hass.states.get("sensor.count_tables")
assert state.state == "0.005"
hass.states.async_set("sensor.input1", "on")
freezer.tick(timedelta(minutes=1))
async_fire_time_changed(hass)
await hass.async_block_till_done(wait_background_tasks=True)
state = hass.states.get("sensor.count_tables")
assert state.state == "0.005"
assert (
"Error rendering query SELECT {{ states.sensor.input1.state | int / 1000}} as value"
" LIMIT 1;: ValueError: Template error: int got invalid input 'on' when rendering"
" template 'SELECT {{ states.sensor.input1.state | int / 1000}} as value LIMIT 1;'"
" but no default was specified" in caplog.text
)
async def test_query_limit(recorder_mock: Recorder, hass: HomeAssistant) -> None:
"""Test the SQL sensor with a query containing 'LIMIT' in lowercase."""
options = {
@@ -641,17 +720,14 @@ async def test_query_recover_from_rollback(
CONF_UNIQUE_ID: "very_unique_id",
}
await init_integration(hass, title="Select value SQL query", options=options)
platforms = async_get_platforms(hass, "sql")
sql_entity = platforms[0].entities["sensor.select_value_sql_query"]
state = hass.states.get("sensor.select_value_sql_query")
assert state.state == "5"
assert state.attributes["value"] == 5
with patch.object(
sql_entity,
"_lambda_stmt",
generate_lambda_stmt("Faulty syntax create operational issue"),
with patch(
"homeassistant.components.sql.sensor.generate_lambda_stmt",
return_value=generate_lambda_stmt("Faulty syntax create operational issue"),
):
freezer.tick(timedelta(minutes=1))
async_fire_time_changed(hass)

View File

@@ -153,7 +153,7 @@ async def test_query_service_invalid_query_not_select(
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
with pytest.raises(vol.Invalid, match="Only SELECT queries allowed"):
with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"):
await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
@@ -171,7 +171,7 @@ async def test_query_service_sqlalchemy_error(
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
with pytest.raises(MultipleInvalid, match="Invalid SQL query"):
with pytest.raises(MultipleInvalid, match="SQL query is empty or unknown type"):
await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,

View File

@@ -13,6 +13,7 @@ from homeassistant.components.sql.util import (
validate_sql_select,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.template import Template
async def test_resolve_db_url_when_none_configured(
@@ -39,27 +40,27 @@ async def test_resolve_db_url_when_configured(hass: HomeAssistant) -> None:
[
(
"DROP TABLE *",
"Only SELECT queries allowed",
"SQL query must be of type SELECT",
),
(
"SELECT5 as value",
"Invalid SQL query",
"SQL query is empty or unknown type",
),
(
";;",
"Invalid SQL query",
"SQL query is empty or unknown type",
),
(
"UPDATE states SET state = 999999 WHERE state_id = 11125",
"Only SELECT queries allowed",
"SQL query must be of type SELECT",
),
(
"WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
"Only SELECT queries allowed",
"SQL query must be of type SELECT",
),
(
"SELECT 5 as value; UPDATE states SET state = 10;",
"Multiple SQL queries are not supported",
"Multiple SQL statements are not allowed",
),
],
)
@@ -70,7 +71,7 @@ async def test_invalid_sql_queries(
) -> None:
"""Test that various invalid or disallowed SQL queries raise the correct exception."""
with pytest.raises(vol.Invalid, match=expected_error_message):
validate_sql_select(sql_query)
validate_sql_select(Template(sql_query, hass))
@pytest.mark.parametrize(