mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Add sql.query action (#147260)
This commit is contained in:
@@ -39,6 +39,7 @@ from .const import (
|
||||
DOMAIN,
|
||||
PLATFORMS,
|
||||
)
|
||||
from .services import async_setup_services
|
||||
from .util import redact_credentials, validate_sql_select
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -71,6 +72,8 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up SQL from yaml config."""
|
||||
async_setup_services(hass)
|
||||
|
||||
if (conf := config.get(DOMAIN)) is None:
|
||||
return True
|
||||
|
||||
|
||||
7
homeassistant/components/sql/icons.json
Normal file
7
homeassistant/components/sql/icons.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"services": {
|
||||
"query": {
|
||||
"service": "mdi:database-search"
|
||||
}
|
||||
}
|
||||
}
|
||||
131
homeassistant/components/sql/services.py
Normal file
131
homeassistant/components/sql/services.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Services for the SQL integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
import logging
|
||||
|
||||
from sqlalchemy.engine import Result
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.recorder import CONF_DB_URL, get_instance
|
||||
from homeassistant.core import (
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.util.json import JsonValueType
|
||||
|
||||
from .const import CONF_QUERY, DOMAIN
|
||||
from .util import (
|
||||
async_create_sessionmaker,
|
||||
generate_lambda_stmt,
|
||||
redact_credentials,
|
||||
resolve_db_url,
|
||||
validate_query,
|
||||
validate_sql_select,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_QUERY = "query"
|
||||
SERVICE_QUERY_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select),
|
||||
vol.Optional(CONF_DB_URL): cv.string,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _async_query_service(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Execute a SQL query service and return the result."""
|
||||
db_url = resolve_db_url(call.hass, call.data.get(CONF_DB_URL))
|
||||
query_str = call.data[CONF_QUERY]
|
||||
(
|
||||
sessmaker,
|
||||
uses_recorder_db,
|
||||
use_database_executor,
|
||||
) = await async_create_sessionmaker(call.hass, db_url)
|
||||
try:
|
||||
validate_query(call.hass, query_str, uses_recorder_db, None)
|
||||
except ValueError as err:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="query_not_allowed",
|
||||
translation_placeholders={"error": str(err)},
|
||||
) from err
|
||||
if sessmaker is None:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="db_connection_failed",
|
||||
translation_placeholders={"db_url": redact_credentials(db_url)},
|
||||
)
|
||||
|
||||
def _execute_and_convert_query() -> list[JsonValueType]:
|
||||
"""Execute the query and return the results with converted types."""
|
||||
sess: Session = sessmaker()
|
||||
try:
|
||||
result: Result = sess.execute(generate_lambda_stmt(query_str))
|
||||
except SQLAlchemyError as err:
|
||||
_LOGGER.debug(
|
||||
"Error executing query %s: %s",
|
||||
query_str,
|
||||
redact_credentials(str(err)),
|
||||
)
|
||||
sess.rollback()
|
||||
raise
|
||||
else:
|
||||
rows: list[JsonValueType] = []
|
||||
for row in result.mappings():
|
||||
processed_row: dict[str, JsonValueType] = {}
|
||||
for key, value in row.items():
|
||||
if isinstance(value, decimal.Decimal):
|
||||
processed_row[key] = float(value)
|
||||
elif isinstance(value, datetime.date):
|
||||
processed_row[key] = value.isoformat()
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
processed_row[key] = f"0x{value.hex()}"
|
||||
else:
|
||||
processed_row[key] = value
|
||||
rows.append(processed_row)
|
||||
return rows
|
||||
finally:
|
||||
sess.close()
|
||||
|
||||
try:
|
||||
if use_database_executor:
|
||||
result = await get_instance(call.hass).async_add_executor_job(
|
||||
_execute_and_convert_query
|
||||
)
|
||||
else:
|
||||
result = await call.hass.async_add_executor_job(_execute_and_convert_query)
|
||||
except SQLAlchemyError as err:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="query_execution_error",
|
||||
translation_placeholders={"error": redact_credentials(str(err))},
|
||||
) from err
|
||||
|
||||
return {"result": result}
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_services(hass: HomeAssistant) -> None:
|
||||
"""Set up the services for the SQL integration."""
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
_async_query_service,
|
||||
schema=SERVICE_QUERY_SCHEMA,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
28
homeassistant/components/sql/services.yaml
Normal file
28
homeassistant/components/sql/services.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
# Describes the format for services provided by the SQL integration.
|
||||
|
||||
query:
|
||||
fields:
|
||||
query:
|
||||
required: true
|
||||
example: |
|
||||
SELECT
|
||||
states.state,
|
||||
last_updated_ts
|
||||
FROM
|
||||
states
|
||||
INNER JOIN states_meta ON
|
||||
states.metadata_id = states_meta.metadata_id
|
||||
WHERE
|
||||
states_meta.entity_id = 'sun.sun'
|
||||
ORDER BY
|
||||
last_updated_ts DESC
|
||||
LIMIT
|
||||
10;
|
||||
selector:
|
||||
text:
|
||||
multiline: true
|
||||
db_url:
|
||||
required: false
|
||||
example: "sqlite:////config/home-assistant_v2.db"
|
||||
selector:
|
||||
text:
|
||||
@@ -51,6 +51,33 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
"db_connection_failed": {
|
||||
"message": "Failed to connect to the database: {db_url}"
|
||||
},
|
||||
"query_execution_error": {
|
||||
"message": "An error occurred when executing the query: {error}"
|
||||
},
|
||||
"query_not_allowed": {
|
||||
"message": "The provided query is not allowed: {error}"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"description": "Executes a SQL query and returns the result.",
|
||||
"fields": {
|
||||
"query": {
|
||||
"name": "Query",
|
||||
"description": "The SELECT query to execute."
|
||||
},
|
||||
"db_url": {
|
||||
"name": "Database URL",
|
||||
"description": "The URL of the database to connect to. If not provided, the default Home Assistant recorder database will be used."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
|
||||
229
tests/components/sql/test_services.py
Normal file
229
tests/components/sql/test_services.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Tests for the SQL integration services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
from voluptuous import MultipleInvalid
|
||||
|
||||
from homeassistant.components.recorder import Recorder
|
||||
from homeassistant.components.sql.const import DOMAIN
|
||||
from homeassistant.components.sql.services import SERVICE_QUERY
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.components.recorder.common import async_wait_recording_done
|
||||
|
||||
|
||||
async def test_query_service_recorder_db(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the query service with the default recorder database."""
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Populate the recorder database with some data
|
||||
hass.states.async_set("sensor.test", "123", {"attr": "value"})
|
||||
hass.states.async_set("sensor.test2", "456")
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{
|
||||
"query": (
|
||||
"SELECT states_meta.entity_id, states.state "
|
||||
"FROM states INNER JOIN states_meta ON states.metadata_id = states_meta.metadata_id "
|
||||
"WHERE states_meta.entity_id LIKE 'sensor.test%' ORDER BY states_meta.entity_id"
|
||||
)
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"result": [
|
||||
{"entity_id": "sensor.test", "state": "123"},
|
||||
{"entity_id": "sensor.test2", "state": "456"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def test_query_service_external_db(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
"""Test the query service with an external database via db_url."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
|
||||
# Create and populate the external database
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute("CREATE TABLE users (name TEXT, age INTEGER)")
|
||||
conn.execute("INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{"query": "SELECT name, age FROM users ORDER BY age", "db_url": db_url},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"result": [
|
||||
{"name": "Bob", "age": 25},
|
||||
{"name": "Alice", "age": 30},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def test_query_service_data_conversion(
|
||||
hass: HomeAssistant, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test the query service correctly converts data types."""
|
||||
db_path = tmp_path / "test_types.db"
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"CREATE TABLE data (id INTEGER, cost DECIMAL(10, 2), event_date DATE, raw BLOB)"
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO data (id, cost, event_date, raw) VALUES (1, 199.99, '2023-01-15', X'DEADBEEF')"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{"query": "SELECT * FROM data", "db_url": db_url},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"result": [
|
||||
{
|
||||
"id": 1,
|
||||
"cost": 199.99, # Converted from Decimal to float
|
||||
"event_date": "2023-01-15", # Converted from date to ISO string
|
||||
"raw": "0xdeadbeef", # Converted from bytes to hex string
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def test_query_service_no_results(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the query service when a query returns no results."""
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
response = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{"query": "SELECT * FROM states"},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert response == {"result": []}
|
||||
|
||||
|
||||
async def test_query_service_invalid_query_not_select(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the service rejects non-SELECT queries."""
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(vol.Invalid, match="Only SELECT queries allowed"):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{"query": "UPDATE states SET state = 'hacked'"},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_query_service_sqlalchemy_error(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the service handles SQLAlchemy errors during query execution."""
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(MultipleInvalid, match="Invalid SQL query"):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
# Syntactically incorrect query
|
||||
{"query": "SELEC * FROM states"},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_query_service_invalid_db_url(hass: HomeAssistant) -> None:
|
||||
"""Test the service handles an invalid database URL."""
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.sql.util._validate_and_get_session_maker_for_db_url",
|
||||
return_value=None,
|
||||
),
|
||||
pytest.raises(
|
||||
ServiceValidationError, match="Failed to connect to the database"
|
||||
),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{
|
||||
"query": "SELECT 1",
|
||||
"db_url": "postgresql://user:pass@host:123/dbname",
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_query_service_performance_issue_validation(
|
||||
recorder_mock: Recorder,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test the service validates queries against the recorder for performance issues."""
|
||||
await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(
|
||||
ServiceValidationError,
|
||||
match="The provided query is not allowed: Query contains entity_id but does not reference states_meta",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_QUERY,
|
||||
{"query": "SELECT entity_id FROM states"},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
Reference in New Issue
Block a user