From 67ccdd36fbe5cd5779b4afe5ad220f207e2bc951 Mon Sep 17 00:00:00 2001 From: G Johansson Date: Thu, 6 Nov 2025 17:11:46 +0100 Subject: [PATCH] Allow template in query in sql (#150287) --- homeassistant/components/sql/__init__.py | 4 +- homeassistant/components/sql/config_flow.py | 47 +++++---- homeassistant/components/sql/sensor.py | 52 ++++++--- homeassistant/components/sql/services.py | 9 +- homeassistant/components/sql/strings.json | 2 +- homeassistant/components/sql/util.py | 72 +++++++++++-- tests/components/sql/__init__.py | 38 +++++++ tests/components/sql/test_config_flow.py | 110 ++++++++++++++++++-- tests/components/sql/test_init.py | 44 ++++++++ tests/components/sql/test_sensor.py | 90 ++++++++++++++-- tests/components/sql/test_services.py | 4 +- tests/components/sql/test_util.py | 15 +-- 12 files changed, 411 insertions(+), 76 deletions(-) diff --git a/homeassistant/components/sql/__init__.py b/homeassistant/components/sql/__init__.py index d658c81be1c..aac9b47b0d4 100644 --- a/homeassistant/components/sql/__init__.py +++ b/homeassistant/components/sql/__init__.py @@ -49,7 +49,9 @@ QUERY_SCHEMA = vol.Schema( { vol.Required(CONF_COLUMN_NAME): cv.string, vol.Required(CONF_NAME): cv.template, - vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select), + vol.Required(CONF_QUERY): vol.All( + cv.template, ValueTemplate.from_template, validate_sql_select + ), vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_VALUE_TEMPLATE): vol.All( cv.template, ValueTemplate.from_template diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index a614105d8bc..6c0fcfb11a4 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -9,8 +9,6 @@ import sqlalchemy from sqlalchemy.engine import Engine, Result from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError from sqlalchemy.orm import Session, scoped_session, sessionmaker -import sqlparse -from sqlparse.exceptions import SQLParseError import voluptuous as vol from homeassistant.components.recorder import CONF_DB_URL, get_instance @@ -31,21 +29,28 @@ from homeassistant.const import ( CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, ) -from homeassistant.core import callback +from homeassistant.core import async_get_hass, callback from homeassistant.data_entry_flow import section +from homeassistant.exceptions import TemplateError from homeassistant.helpers import selector from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN -from .util import resolve_db_url +from .util import ( + EmptyQueryError, + InvalidSqlQuery, + MultipleQueryError, + NotSelectQueryError, + UnknownQueryTypeError, + check_and_render_sql_query, + resolve_db_url, +) _LOGGER = logging.getLogger(__name__) OPTIONS_SCHEMA: vol.Schema = vol.Schema( { - vol.Required(CONF_QUERY): selector.TextSelector( - selector.TextSelectorConfig(multiline=True) - ), + vol.Required(CONF_QUERY): selector.TemplateSelector(), vol.Required(CONF_COLUMN_NAME): selector.TextSelector(), vol.Required(CONF_ADVANCED_OPTIONS): section( vol.Schema( @@ -89,14 +94,12 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema( def validate_sql_select(value: str) -> str: """Validate that value is a SQL SELECT query.""" - if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: - raise MultipleResultsFound - if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": - raise ValueError - if query_type != "SELECT": - _LOGGER.debug("The SQL query %s is of type %s", query, query_type) - raise SQLParseError - return str(query[0]) + hass = async_get_hass() + try: + return check_and_render_sql_query(hass, value) + except (TemplateError, InvalidSqlQuery) as err: + _LOGGER.debug("Invalid query '%s' results in '%s'", value, err.args[0]) + raise def validate_db_connection(db_url: str) -> bool: @@ -138,7 +141,7 @@ def validate_query(db_url: str, query: str, column: str) -> bool: if sess: sess.close() engine.dispose() - raise ValueError(error) from error + raise InvalidSqlQuery from error for res in result.mappings(): if column not in res: @@ -224,13 +227,13 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN): except NoSuchColumnError: errors["column"] = "column_invalid" description_placeholders = {"column": column} - except MultipleResultsFound: + except (MultipleResultsFound, MultipleQueryError): errors["query"] = "multiple_queries" except SQLAlchemyError: errors["db_url"] = "db_url_invalid" - except SQLParseError: + except (NotSelectQueryError, UnknownQueryTypeError): errors["query"] = "query_no_read_only" - except ValueError as err: + except (TemplateError, EmptyQueryError, InvalidSqlQuery) as err: _LOGGER.debug("Invalid query: %s", err) errors["query"] = "query_invalid" @@ -282,13 +285,13 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload): except NoSuchColumnError: errors["column"] = "column_invalid" description_placeholders = {"column": column} - except MultipleResultsFound: + except (MultipleResultsFound, MultipleQueryError): errors["query"] = "multiple_queries" except SQLAlchemyError: errors["db_url"] = "db_url_invalid" - except SQLParseError: + except (NotSelectQueryError, UnknownQueryTypeError): errors["query"] = "query_no_read_only" - except ValueError as err: + except (TemplateError, EmptyQueryError, InvalidSqlQuery) as err: _LOGGER.debug("Invalid query: %s", err) errors["query"] = "query_invalid" else: diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index c8885cd2377..dddd1386932 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -22,7 +22,7 @@ from homeassistant.const import ( MATCH_ALL, ) from homeassistant.core import HomeAssistant -from homeassistant.exceptions import TemplateError +from homeassistant.exceptions import PlatformNotReady, TemplateError from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import ( AddConfigEntryEntitiesCallback, @@ -40,7 +40,9 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN from .util import ( + InvalidSqlQuery, async_create_sessionmaker, + check_and_render_sql_query, convert_value, generate_lambda_stmt, redact_credentials, @@ -81,7 +83,7 @@ async def async_setup_platform( return name: Template = conf[CONF_NAME] - query_str: str = conf[CONF_QUERY] + query_template: ValueTemplate = conf[CONF_QUERY] value_template: ValueTemplate | None = conf.get(CONF_VALUE_TEMPLATE) column_name: str = conf[CONF_COLUMN_NAME] unique_id: str | None = conf.get(CONF_UNIQUE_ID) @@ -96,7 +98,7 @@ async def async_setup_platform( await async_setup_sensor( hass, trigger_entity_config, - query_str, + query_template, column_name, value_template, unique_id, @@ -119,6 +121,13 @@ async def async_setup_entry( template: str | None = entry.options[CONF_ADVANCED_OPTIONS].get(CONF_VALUE_TEMPLATE) column_name: str = entry.options[CONF_COLUMN_NAME] + query_template: ValueTemplate | None = None + try: + query_template = ValueTemplate(query_str, hass) + query_template.ensure_valid() + except TemplateError as err: + raise PlatformNotReady("Invalid SQL query template") from err + value_template: ValueTemplate | None = None if template is not None: try: @@ -137,7 +146,7 @@ async def async_setup_entry( await async_setup_sensor( hass, trigger_entity_config, - query_str, + query_template, column_name, value_template, entry.entry_id, @@ -150,7 +159,7 @@ async def async_setup_entry( async def async_setup_sensor( hass: HomeAssistant, trigger_entity_config: ConfigType, - query_str: str, + query_template: ValueTemplate, column_name: str, value_template: ValueTemplate | None, unique_id: str | None, @@ -166,22 +175,25 @@ async def async_setup_sensor( ) = await async_create_sessionmaker(hass, db_url) if sessmaker is None: return - validate_query(hass, query_str, uses_recorder_db, unique_id) + validate_query(hass, query_template, uses_recorder_db, unique_id) + query_str = check_and_render_sql_query(hass, query_template) upper_query = query_str.upper() # MSSQL uses TOP and not LIMIT + mod_query_template = query_template if not ("LIMIT" in upper_query or "SELECT TOP" in upper_query): if "mssql" in db_url: - query_str = upper_query.replace("SELECT", "SELECT TOP 1") + _query = query_template.template.replace("SELECT", "SELECT TOP 1") else: - query_str = query_str.replace(";", "") + " LIMIT 1;" + _query = query_template.template.replace(";", "") + " LIMIT 1;" + mod_query_template = ValueTemplate(_query, hass) async_add_entities( [ SQLSensor( trigger_entity_config, sessmaker, - query_str, + mod_query_template, column_name, value_template, yaml, @@ -200,7 +212,7 @@ class SQLSensor(ManualTriggerSensorEntity): self, trigger_entity_config: ConfigType, sessmaker: scoped_session, - query: str, + query: ValueTemplate, column: str, value_template: ValueTemplate | None, yaml: bool, @@ -214,7 +226,6 @@ class SQLSensor(ManualTriggerSensorEntity): self.sessionmaker = sessmaker self._attr_extra_state_attributes = {} self._use_database_executor = use_database_executor - self._lambda_stmt = generate_lambda_stmt(query) if not yaml and (unique_id := trigger_entity_config.get(CONF_UNIQUE_ID)): self._attr_name = None self._attr_has_entity_name = True @@ -255,11 +266,22 @@ class SQLSensor(ManualTriggerSensorEntity): self._attr_extra_state_attributes = {} sess: scoped_session = self.sessionmaker() try: - result: Result = sess.execute(self._lambda_stmt) + rendered_query = check_and_render_sql_query(self.hass, self._query) + _lambda_stmt = generate_lambda_stmt(rendered_query) + result: Result = sess.execute(_lambda_stmt) + except (TemplateError, InvalidSqlQuery) as err: + _LOGGER.error( + "Error rendering query %s: %s", + redact_credentials(self._query.template), + redact_credentials(str(err)), + ) + sess.rollback() + sess.close() + return except SQLAlchemyError as err: _LOGGER.error( "Error executing query %s: %s", - self._query, + rendered_query, redact_credentials(str(err)), ) sess.rollback() @@ -267,7 +289,7 @@ class SQLSensor(ManualTriggerSensorEntity): return for res in result.mappings(): - _LOGGER.debug("Query %s result in %s", self._query, res.items()) + _LOGGER.debug("Query %s result in %s", rendered_query, res.items()) data = res[self._column_name] for key, value in res.items(): self._attr_extra_state_attributes[key] = convert_value(value) @@ -287,6 +309,6 @@ class SQLSensor(ManualTriggerSensorEntity): self._attr_native_value = data if data is None: - _LOGGER.warning("%s returned no results", self._query) + _LOGGER.warning("%s returned no results", rendered_query) sess.close() diff --git a/homeassistant/components/sql/services.py b/homeassistant/components/sql/services.py index dc31064d3ec..6ab97a2e665 100644 --- a/homeassistant/components/sql/services.py +++ b/homeassistant/components/sql/services.py @@ -19,11 +19,13 @@ from homeassistant.core import ( ) from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.trigger_template_entity import ValueTemplate from homeassistant.util.json import JsonValueType from .const import CONF_QUERY, DOMAIN from .util import ( async_create_sessionmaker, + check_and_render_sql_query, convert_value, generate_lambda_stmt, redact_credentials, @@ -37,7 +39,9 @@ _LOGGER = logging.getLogger(__name__) SERVICE_QUERY = "query" SERVICE_QUERY_SCHEMA = vol.Schema( { - vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select), + vol.Required(CONF_QUERY): vol.All( + cv.template, ValueTemplate.from_template, validate_sql_select + ), vol.Optional(CONF_DB_URL): cv.string, } ) @@ -72,8 +76,9 @@ async def _async_query_service( def _execute_and_convert_query() -> list[JsonValueType]: """Execute the query and return the results with converted types.""" sess: Session = sessmaker() + rendered_query = check_and_render_sql_query(call.hass, query_str) try: - result: Result = sess.execute(generate_lambda_stmt(query_str)) + result: Result = sess.execute(generate_lambda_stmt(rendered_query)) except SQLAlchemyError as err: _LOGGER.debug( "Error executing query %s: %s", diff --git a/homeassistant/components/sql/strings.json b/homeassistant/components/sql/strings.json index 2a2cb6ab47f..00f8c1fc815 100644 --- a/homeassistant/components/sql/strings.json +++ b/homeassistant/components/sql/strings.json @@ -8,7 +8,7 @@ "db_url_invalid": "Database URL invalid", "multiple_queries": "Multiple SQL queries are not supported", "query_invalid": "SQL query invalid", - "query_no_read_only": "SQL query must be read-only" + "query_no_read_only": "SQL query is not a read-only SELECT query or it's of an unknown type" }, "step": { "options": { diff --git a/homeassistant/components/sql/util.py b/homeassistant/components/sql/util.py index cc6f1bb5ea1..f5b49187ba8 100644 --- a/homeassistant/components/sql/util.py +++ b/homeassistant/components/sql/util.py @@ -19,7 +19,9 @@ import voluptuous as vol from homeassistant.components.recorder import SupportedDialect, get_instance from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.helpers import issue_registry as ir +from homeassistant.helpers.template import Template from .const import DB_URL_RE, DOMAIN from .models import SQLData @@ -44,16 +46,14 @@ def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str: return get_instance(hass).db_url -def validate_sql_select(value: str) -> str: +def validate_sql_select(value: Template) -> Template: """Validate that value is a SQL SELECT query.""" - if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: - raise vol.Invalid("Multiple SQL queries are not supported") - if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": - raise vol.Invalid("Invalid SQL query") - if query_type != "SELECT": - _LOGGER.debug("The SQL query %s is of type %s", query, query_type) - raise vol.Invalid("Only SELECT queries allowed") - return str(query[0]) + try: + assert value.hass + check_and_render_sql_query(value.hass, value) + except (TemplateError, InvalidSqlQuery) as err: + raise vol.Invalid(str(err)) from err + return value async def async_create_sessionmaker( @@ -113,7 +113,7 @@ async def async_create_sessionmaker( def validate_query( hass: HomeAssistant, - query_str: str, + query_template: str | Template, uses_recorder_db: bool, unique_id: str | None = None, ) -> None: @@ -121,7 +121,7 @@ def validate_query( Args: hass: The Home Assistant instance. - query_str: The SQL query string to be validated. + query_template: The SQL query string to be validated. uses_recorder_db: A boolean indicating if the query is against the recorder database. unique_id: The unique ID of the entity, used for creating issue registry keys. @@ -131,6 +131,10 @@ def validate_query( """ if not uses_recorder_db: return + if isinstance(query_template, Template): + query_str = query_template.async_render() + else: + query_str = Template(query_template, hass).async_render() redacted_query = redact_credentials(query_str) issue_key = unique_id if unique_id else redacted_query @@ -239,3 +243,49 @@ def convert_value(value: Any) -> Any: return f"0x{value.hex()}" case _: return value + + +def check_and_render_sql_query(hass: HomeAssistant, query: Template | str) -> str: + """Check and render SQL query.""" + if isinstance(query, str): + query = query.strip() + if not query: + raise EmptyQueryError("Query cannot be empty") + query = Template(query, hass=hass) + + # Raises TemplateError if template is invalid + query.ensure_valid() + rendered_query: str = query.async_render() + + if len(rendered_queries := sqlparse.parse(rendered_query.lstrip().lstrip(";"))) > 1: + raise MultipleQueryError("Multiple SQL statements are not allowed") + if ( + len(rendered_queries) == 0 + or (query_type := rendered_queries[0].get_type()) == "UNKNOWN" + ): + raise UnknownQueryTypeError("SQL query is empty or unknown type") + if query_type != "SELECT": + _LOGGER.debug("The SQL query %s is of type %s", rendered_query, query_type) + raise NotSelectQueryError("SQL query must be of type SELECT") + + return str(rendered_queries[0]) + + +class InvalidSqlQuery(HomeAssistantError): + """SQL query is invalid error.""" + + +class EmptyQueryError(InvalidSqlQuery): + """SQL query is empty error.""" + + +class MultipleQueryError(InvalidSqlQuery): + """SQL query is multiple error.""" + + +class UnknownQueryTypeError(InvalidSqlQuery): + """SQL query is of unknown type error.""" + + +class NotSelectQueryError(InvalidSqlQuery): + """SQL query is not a SELECT statement error.""" diff --git a/tests/components/sql/__init__.py b/tests/components/sql/__init__.py index 6afc0329e32..c327059278c 100644 --- a/tests/components/sql/__init__.py +++ b/tests/components/sql/__init__.py @@ -44,6 +44,17 @@ ENTRY_CONFIG = { }, } +ENTRY_CONFIG_BLANK_QUERY = { + CONF_NAME: "Get Value", + CONF_QUERY: " ", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + }, +} + ENTRY_CONFIG_WITH_VALUE_TEMPLATE = { CONF_QUERY: "SELECT 5 as value", CONF_COLUMN_NAME: "value", @@ -53,6 +64,33 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = { }, } +ENTRY_CONFIG_WITH_QUERY_TEMPLATE = { + CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", + }, +} + +ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE = { + CONF_QUERY: "SELECT {{ 5 as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", + }, +} + +ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT = { + CONF_QUERY: "SELECT {{ 5 as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", + }, +} + ENTRY_CONFIG_INVALID_QUERY = { CONF_QUERY: "SELECT 5 FROM as value", CONF_COLUMN_NAME: "size", diff --git a/tests/components/sql/test_config_flow.py b/tests/components/sql/test_config_flow.py index 863e87b5eae..d39f28dba82 100644 --- a/tests/components/sql/test_config_flow.py +++ b/tests/components/sql/test_config_flow.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path +import re from typing import Any from unittest.mock import patch @@ -10,7 +11,7 @@ import pytest from sqlalchemy.exc import SQLAlchemyError from homeassistant import config_entries -from homeassistant.components.recorder import CONF_DB_URL +from homeassistant.components.recorder import CONF_DB_URL, Recorder from homeassistant.components.sensor import ( CONF_STATE_CLASS, SensorDeviceClass, @@ -29,7 +30,7 @@ from homeassistant.const import ( CONF_VALUE_TEMPLATE, ) from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import FlowResultType +from homeassistant.data_entry_flow import FlowResultType, InvalidData from . import ( ENTRY_CONFIG, @@ -48,6 +49,9 @@ from . import ( ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE, ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT, ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT, + ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE, + ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT, + ENTRY_CONFIG_WITH_QUERY_TEMPLATE, ENTRY_CONFIG_WITH_VALUE_TEMPLATE, ) @@ -106,7 +110,91 @@ async def test_form_simple( } -async def test_form_with_value_template(hass: HomeAssistant) -> None: +async def test_form_with_query_template( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: + """Test for with query template.""" + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {} + + with patch( + "homeassistant.components.sql.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + DATA_CONFIG, + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + ENTRY_CONFIG_WITH_QUERY_TEMPLATE, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == "Get Value" + assert result["options"] == { + CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", + }, + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_form_with_broken_query_template( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: + """Test form with broken query template.""" + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {} + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + DATA_CONFIG, + ) + message = re.escape("Schema validation failed @ data['query']") + with pytest.raises(InvalidData, match=message): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE, + ) + + with patch( + "homeassistant.components.sql.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + ENTRY_CONFIG_WITH_QUERY_TEMPLATE, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == "Get Value" + assert result["options"] == { + CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", + }, + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_form_with_value_template( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: """Test for with value template.""" result = await hass.config_entries.flow.async_init( @@ -192,7 +280,7 @@ async def test_flow_fails_invalid_query(hass: HomeAssistant) -> None: assert result["type"] is FlowResultType.FORM assert result["errors"] == { - CONF_QUERY: "query_invalid", + CONF_QUERY: "query_no_read_only", } result = await hass.config_entries.flow.async_configure( @@ -202,7 +290,7 @@ async def test_flow_fails_invalid_query(hass: HomeAssistant) -> None: assert result["type"] is FlowResultType.FORM assert result["errors"] == { - CONF_QUERY: "query_invalid", + CONF_QUERY: "query_no_read_only", } result = await hass.config_entries.flow.async_configure( @@ -484,7 +572,7 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None: assert result["type"] is FlowResultType.FORM assert result["errors"] == { - CONF_QUERY: "query_invalid", + CONF_QUERY: "query_no_read_only", } result = await hass.config_entries.options.async_configure( @@ -494,9 +582,8 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None: assert result["type"] is FlowResultType.FORM assert result["errors"] == { - CONF_QUERY: "query_invalid", + CONF_QUERY: "query_no_read_only", } - result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT, @@ -527,6 +614,13 @@ async def test_options_flow_fails_invalid_query(hass: HomeAssistant) -> None: CONF_QUERY: "multiple_queries", } + message = re.escape("Schema validation failed @ data['query']") + with pytest.raises(InvalidData, match=message): + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT, + ) + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ diff --git a/tests/components/sql/test_init.py b/tests/components/sql/test_init.py index c07d5c9e639..c8d77534235 100644 --- a/tests/components/sql/test_init.py +++ b/tests/components/sql/test_init.py @@ -4,6 +4,9 @@ from __future__ import annotations from unittest.mock import patch +import pytest +import voluptuous as vol + from homeassistant.components.recorder import CONF_DB_URL, Recorder from homeassistant.components.sensor import ( CONF_STATE_CLASS, @@ -16,6 +19,7 @@ from homeassistant.components.sql.const import ( CONF_QUERY, DOMAIN, ) +from homeassistant.components.sql.util import validate_sql_select from homeassistant.config_entries import SOURCE_USER, ConfigEntryState from homeassistant.const import ( CONF_DEVICE_CLASS, @@ -24,6 +28,7 @@ from homeassistant.const import ( CONF_VALUE_TEMPLATE, ) from homeassistant.core import HomeAssistant +from homeassistant.helpers.template import Template from homeassistant.setup import async_setup_component from . import YAML_CONFIG_INVALID, YAML_CONFIG_NO_DB, init_integration @@ -67,6 +72,45 @@ async def test_setup_invalid_config( await hass.async_block_till_done() +async def test_invalid_query(hass: HomeAssistant) -> None: + """Test invalid query.""" + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): + validate_sql_select(Template("DROP TABLE *", hass)) + + with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"): + validate_sql_select(Template("SELECT5 as value", hass)) + + with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"): + validate_sql_select(Template(";;", hass)) + + +async def test_query_no_read_only(hass: HomeAssistant) -> None: + """Test query no read only.""" + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): + validate_sql_select( + Template("UPDATE states SET state = 999999 WHERE state_id = 11125", hass) + ) + + +async def test_query_no_read_only_cte(hass: HomeAssistant) -> None: + """Test query no read only CTE.""" + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): + validate_sql_select( + Template( + "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;", + hass, + ) + ) + + +async def test_multiple_queries(hass: HomeAssistant) -> None: + """Test multiple queries.""" + with pytest.raises(vol.Invalid, match="Multiple SQL statements are not allowed"): + validate_sql_select( + Template("SELECT 5 as value; UPDATE states SET state = 10;", hass) + ) + + async def test_migration_from_future( recorder_mock: Recorder, hass: HomeAssistant ) -> None: diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index 73879065999..42ed1a463bd 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -39,7 +39,6 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir -from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util @@ -109,6 +108,33 @@ async def test_query_value_template( } +async def test_template_query( + recorder_mock: Recorder, + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, +) -> None: + """Test the SQL sensor with a query template.""" + options = { + CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_VALUE_TEMPLATE: "{{ value | int }}", + }, + } + await init_integration(hass, title="count_tables", options=options) + + state = hass.states.get("sensor.count_tables") + assert state.state == "6" + + hass.states.async_set("sensor.input1", "on") + freezer.tick(timedelta(minutes=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done(wait_background_tasks=True) + + state = hass.states.get("sensor.count_tables") + assert state.state == "5" + + async def test_query_value_template_invalid( recorder_mock: Recorder, hass: HomeAssistant ) -> None: @@ -124,6 +150,59 @@ async def test_query_value_template_invalid( assert state.state == "5.01" +async def test_broken_template_query( + recorder_mock: Recorder, + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, +) -> None: + """Test the SQL sensor with a query template which is broken.""" + options = { + CONF_QUERY: "SELECT {{ 5 as value", + CONF_COLUMN_NAME: "value", + CONF_ADVANCED_OPTIONS: { + CONF_VALUE_TEMPLATE: "{{ value | int }}", + }, + } + await init_integration(hass, title="count_tables", options=options) + + state = hass.states.get("sensor.count_tables") + assert not state + + +async def test_broken_template_query_2( + recorder_mock: Recorder, + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the SQL sensor with a query template.""" + hass.states.async_set("sensor.input1", "5") + await hass.async_block_till_done(wait_background_tasks=True) + + options = { + CONF_QUERY: "SELECT {{ states.sensor.input1.state | int / 1000}} as value", + CONF_COLUMN_NAME: "value", + } + await init_integration(hass, title="count_tables", options=options) + + state = hass.states.get("sensor.count_tables") + assert state.state == "0.005" + + hass.states.async_set("sensor.input1", "on") + freezer.tick(timedelta(minutes=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done(wait_background_tasks=True) + + state = hass.states.get("sensor.count_tables") + assert state.state == "0.005" + assert ( + "Error rendering query SELECT {{ states.sensor.input1.state | int / 1000}} as value" + " LIMIT 1;: ValueError: Template error: int got invalid input 'on' when rendering" + " template 'SELECT {{ states.sensor.input1.state | int / 1000}} as value LIMIT 1;'" + " but no default was specified" in caplog.text + ) + + async def test_query_limit(recorder_mock: Recorder, hass: HomeAssistant) -> None: """Test the SQL sensor with a query containing 'LIMIT' in lowercase.""" options = { @@ -641,17 +720,14 @@ async def test_query_recover_from_rollback( CONF_UNIQUE_ID: "very_unique_id", } await init_integration(hass, title="Select value SQL query", options=options) - platforms = async_get_platforms(hass, "sql") - sql_entity = platforms[0].entities["sensor.select_value_sql_query"] state = hass.states.get("sensor.select_value_sql_query") assert state.state == "5" assert state.attributes["value"] == 5 - with patch.object( - sql_entity, - "_lambda_stmt", - generate_lambda_stmt("Faulty syntax create operational issue"), + with patch( + "homeassistant.components.sql.sensor.generate_lambda_stmt", + return_value=generate_lambda_stmt("Faulty syntax create operational issue"), ): freezer.tick(timedelta(minutes=1)) async_fire_time_changed(hass) diff --git a/tests/components/sql/test_services.py b/tests/components/sql/test_services.py index ad1fa202153..0ef2f144a01 100644 --- a/tests/components/sql/test_services.py +++ b/tests/components/sql/test_services.py @@ -153,7 +153,7 @@ async def test_query_service_invalid_query_not_select( await async_setup_component(hass, DOMAIN, {}) await hass.async_block_till_done() - with pytest.raises(vol.Invalid, match="Only SELECT queries allowed"): + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): await hass.services.async_call( DOMAIN, SERVICE_QUERY, @@ -171,7 +171,7 @@ async def test_query_service_sqlalchemy_error( await async_setup_component(hass, DOMAIN, {}) await hass.async_block_till_done() - with pytest.raises(MultipleInvalid, match="Invalid SQL query"): + with pytest.raises(MultipleInvalid, match="SQL query is empty or unknown type"): await hass.services.async_call( DOMAIN, SERVICE_QUERY, diff --git a/tests/components/sql/test_util.py b/tests/components/sql/test_util.py index 7023fb17cc2..9df84d061d8 100644 --- a/tests/components/sql/test_util.py +++ b/tests/components/sql/test_util.py @@ -13,6 +13,7 @@ from homeassistant.components.sql.util import ( validate_sql_select, ) from homeassistant.core import HomeAssistant +from homeassistant.helpers.template import Template async def test_resolve_db_url_when_none_configured( @@ -39,27 +40,27 @@ async def test_resolve_db_url_when_configured(hass: HomeAssistant) -> None: [ ( "DROP TABLE *", - "Only SELECT queries allowed", + "SQL query must be of type SELECT", ), ( "SELECT5 as value", - "Invalid SQL query", + "SQL query is empty or unknown type", ), ( ";;", - "Invalid SQL query", + "SQL query is empty or unknown type", ), ( "UPDATE states SET state = 999999 WHERE state_id = 11125", - "Only SELECT queries allowed", + "SQL query must be of type SELECT", ), ( "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;", - "Only SELECT queries allowed", + "SQL query must be of type SELECT", ), ( "SELECT 5 as value; UPDATE states SET state = 10;", - "Multiple SQL queries are not supported", + "Multiple SQL statements are not allowed", ), ], ) @@ -70,7 +71,7 @@ async def test_invalid_sql_queries( ) -> None: """Test that various invalid or disallowed SQL queries raise the correct exception.""" with pytest.raises(vol.Invalid, match=expected_error_message): - validate_sql_select(sql_query) + validate_sql_select(Template(sql_query, hass)) @pytest.mark.parametrize(