1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-15 07:36:16 +00:00

Add sql.query action (#147260)

This commit is contained in:
tronikos
2025-10-23 02:08:37 -07:00
committed by GitHub
parent 5d644815fa
commit 6c919e698f
6 changed files with 425 additions and 0 deletions

View File

@@ -39,6 +39,7 @@ from .const import (
DOMAIN,
PLATFORMS,
)
from .services import async_setup_services
from .util import redact_credentials, validate_sql_select
_LOGGER = logging.getLogger(__name__)
@@ -71,6 +72,8 @@ CONFIG_SCHEMA = vol.Schema(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up SQL from yaml config."""
async_setup_services(hass)
if (conf := config.get(DOMAIN)) is None:
return True

View File

@@ -0,0 +1,7 @@
{
"services": {
"query": {
"service": "mdi:database-search"
}
}
}

View File

@@ -0,0 +1,131 @@
"""Services for the SQL integration."""
from __future__ import annotations
import datetime
import decimal
import logging
from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
import voluptuous as vol
from homeassistant.components.recorder import CONF_DB_URL, get_instance
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import config_validation as cv
from homeassistant.util.json import JsonValueType
from .const import CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
validate_query,
validate_sql_select,
)
_LOGGER = logging.getLogger(__name__)
SERVICE_QUERY = "query"
SERVICE_QUERY_SCHEMA = vol.Schema(
{
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
vol.Optional(CONF_DB_URL): cv.string,
}
)
async def _async_query_service(
call: ServiceCall,
) -> ServiceResponse:
"""Execute a SQL query service and return the result."""
db_url = resolve_db_url(call.hass, call.data.get(CONF_DB_URL))
query_str = call.data[CONF_QUERY]
(
sessmaker,
uses_recorder_db,
use_database_executor,
) = await async_create_sessionmaker(call.hass, db_url)
try:
validate_query(call.hass, query_str, uses_recorder_db, None)
except ValueError as err:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="query_not_allowed",
translation_placeholders={"error": str(err)},
) from err
if sessmaker is None:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="db_connection_failed",
translation_placeholders={"db_url": redact_credentials(db_url)},
)
def _execute_and_convert_query() -> list[JsonValueType]:
"""Execute the query and return the results with converted types."""
sess: Session = sessmaker()
try:
result: Result = sess.execute(generate_lambda_stmt(query_str))
except SQLAlchemyError as err:
_LOGGER.debug(
"Error executing query %s: %s",
query_str,
redact_credentials(str(err)),
)
sess.rollback()
raise
else:
rows: list[JsonValueType] = []
for row in result.mappings():
processed_row: dict[str, JsonValueType] = {}
for key, value in row.items():
if isinstance(value, decimal.Decimal):
processed_row[key] = float(value)
elif isinstance(value, datetime.date):
processed_row[key] = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
processed_row[key] = f"0x{value.hex()}"
else:
processed_row[key] = value
rows.append(processed_row)
return rows
finally:
sess.close()
try:
if use_database_executor:
result = await get_instance(call.hass).async_add_executor_job(
_execute_and_convert_query
)
else:
result = await call.hass.async_add_executor_job(_execute_and_convert_query)
except SQLAlchemyError as err:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="query_execution_error",
translation_placeholders={"error": redact_credentials(str(err))},
) from err
return {"result": result}
@callback
def async_setup_services(hass: HomeAssistant) -> None:
"""Set up the services for the SQL integration."""
hass.services.async_register(
DOMAIN,
SERVICE_QUERY,
_async_query_service,
schema=SERVICE_QUERY_SCHEMA,
supports_response=SupportsResponse.ONLY,
)

View File

@@ -0,0 +1,28 @@
# Describes the format for services provided by the SQL integration.
query:
fields:
query:
required: true
example: |
SELECT
states.state,
last_updated_ts
FROM
states
INNER JOIN states_meta ON
states.metadata_id = states_meta.metadata_id
WHERE
states_meta.entity_id = 'sun.sun'
ORDER BY
last_updated_ts DESC
LIMIT
10;
selector:
text:
multiline: true
db_url:
required: false
example: "sqlite:////config/home-assistant_v2.db"
selector:
text:

View File

@@ -51,6 +51,33 @@
}
}
},
"exceptions": {
"db_connection_failed": {
"message": "Failed to connect to the database: {db_url}"
},
"query_execution_error": {
"message": "An error occurred when executing the query: {error}"
},
"query_not_allowed": {
"message": "The provided query is not allowed: {error}"
}
},
"services": {
"query": {
"name": "Query",
"description": "Executes a SQL query and returns the result.",
"fields": {
"query": {
"name": "Query",
"description": "The SELECT query to execute."
},
"db_url": {
"name": "Database URL",
"description": "The URL of the database to connect to. If not provided, the default Home Assistant recorder database will be used."
}
}
}
},
"options": {
"step": {
"init": {

View File

@@ -0,0 +1,229 @@
"""Tests for the SQL integration services."""
from __future__ import annotations
from pathlib import Path
import sqlite3
from unittest.mock import patch
import pytest
import voluptuous as vol
from voluptuous import MultipleInvalid
from homeassistant.components.recorder import Recorder
from homeassistant.components.sql.const import DOMAIN
from homeassistant.components.sql.services import SERVICE_QUERY
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError
from homeassistant.setup import async_setup_component
from tests.components.recorder.common import async_wait_recording_done
async def test_query_service_recorder_db(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test the query service with the default recorder database."""
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
# Populate the recorder database with some data
hass.states.async_set("sensor.test", "123", {"attr": "value"})
hass.states.async_set("sensor.test2", "456")
await async_wait_recording_done(hass)
response = await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{
"query": (
"SELECT states_meta.entity_id, states.state "
"FROM states INNER JOIN states_meta ON states.metadata_id = states_meta.metadata_id "
"WHERE states_meta.entity_id LIKE 'sensor.test%' ORDER BY states_meta.entity_id"
)
},
blocking=True,
return_response=True,
)
assert response == {
"result": [
{"entity_id": "sensor.test", "state": "123"},
{"entity_id": "sensor.test2", "state": "456"},
]
}
async def test_query_service_external_db(hass: HomeAssistant, tmp_path: Path) -> None:
"""Test the query service with an external database via db_url."""
db_path = tmp_path / "test.db"
db_url = f"sqlite:///{db_path}"
# Create and populate the external database
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE users (name TEXT, age INTEGER)")
conn.execute("INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)")
conn.commit()
conn.close()
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
response = await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{"query": "SELECT name, age FROM users ORDER BY age", "db_url": db_url},
blocking=True,
return_response=True,
)
assert response == {
"result": [
{"name": "Bob", "age": 25},
{"name": "Alice", "age": 30},
]
}
async def test_query_service_data_conversion(
hass: HomeAssistant, tmp_path: Path
) -> None:
"""Test the query service correctly converts data types."""
db_path = tmp_path / "test_types.db"
db_url = f"sqlite:///{db_path}"
conn = sqlite3.connect(db_path)
conn.execute(
"CREATE TABLE data (id INTEGER, cost DECIMAL(10, 2), event_date DATE, raw BLOB)"
)
conn.execute(
"INSERT INTO data (id, cost, event_date, raw) VALUES (1, 199.99, '2023-01-15', X'DEADBEEF')"
)
conn.commit()
conn.close()
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
response = await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{"query": "SELECT * FROM data", "db_url": db_url},
blocking=True,
return_response=True,
)
assert response == {
"result": [
{
"id": 1,
"cost": 199.99, # Converted from Decimal to float
"event_date": "2023-01-15", # Converted from date to ISO string
"raw": "0xdeadbeef", # Converted from bytes to hex string
}
]
}
async def test_query_service_no_results(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test the query service when a query returns no results."""
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
response = await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{"query": "SELECT * FROM states"},
blocking=True,
return_response=True,
)
assert response == {"result": []}
async def test_query_service_invalid_query_not_select(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test the service rejects non-SELECT queries."""
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
with pytest.raises(vol.Invalid, match="Only SELECT queries allowed"):
await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{"query": "UPDATE states SET state = 'hacked'"},
blocking=True,
return_response=True,
)
async def test_query_service_sqlalchemy_error(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test the service handles SQLAlchemy errors during query execution."""
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
with pytest.raises(MultipleInvalid, match="Invalid SQL query"):
await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
# Syntactically incorrect query
{"query": "SELEC * FROM states"},
blocking=True,
return_response=True,
)
async def test_query_service_invalid_db_url(hass: HomeAssistant) -> None:
"""Test the service handles an invalid database URL."""
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
with (
patch(
"homeassistant.components.sql.util._validate_and_get_session_maker_for_db_url",
return_value=None,
),
pytest.raises(
ServiceValidationError, match="Failed to connect to the database"
),
):
await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{
"query": "SELECT 1",
"db_url": "postgresql://user:pass@host:123/dbname",
},
blocking=True,
return_response=True,
)
async def test_query_service_performance_issue_validation(
recorder_mock: Recorder,
hass: HomeAssistant,
) -> None:
"""Test the service validates queries against the recorder for performance issues."""
await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done()
with pytest.raises(
ServiceValidationError,
match="The provided query is not allowed: Query contains entity_id but does not reference states_meta",
):
await hass.services.async_call(
DOMAIN,
SERVICE_QUERY,
{"query": "SELECT entity_id FROM states"},
blocking=True,
return_response=True,
)