1
0
mirror of https://github.com/home-assistant/core.git synced 2026-03-01 06:16:29 +00:00

Refactor SQL's data conversion (#155598)

This commit is contained in:
David Rapan
2025-11-02 18:49:18 +01:00
committed by GitHub
parent c67e005b2c
commit 144fc2a443
4 changed files with 48 additions and 22 deletions

View File

@@ -2,8 +2,6 @@
from __future__ import annotations
from datetime import date
import decimal
import logging
from typing import Any
@@ -43,6 +41,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
convert_value,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
@@ -253,7 +252,6 @@ class SQLSensor(ManualTriggerSensorEntity):
def _update(self) -> None:
"""Retrieve sensor data from the query."""
data = None
extra_state_attributes = {}
self._attr_extra_state_attributes = {}
sess: scoped_session = self.sessionmaker()
try:
@@ -272,14 +270,7 @@ class SQLSensor(ManualTriggerSensorEntity):
_LOGGER.debug("Query %s result in %s", self._query, res.items())
data = res[self._column_name]
for key, value in res.items():
if isinstance(value, decimal.Decimal):
value = float(value)
elif isinstance(value, date):
value = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
value = f"0x{value.hex()}"
extra_state_attributes[key] = value
self._attr_extra_state_attributes[key] = value
self._attr_extra_state_attributes[key] = convert_value(value)
if data is not None and isinstance(data, (bytes, bytearray)):
data = f"0x{data.hex()}"

View File

@@ -2,8 +2,6 @@
from __future__ import annotations
import datetime
import decimal
import logging
from sqlalchemy.engine import Result
@@ -26,6 +24,7 @@ from homeassistant.util.json import JsonValueType
from .const import CONF_QUERY, DOMAIN
from .util import (
async_create_sessionmaker,
convert_value,
generate_lambda_stmt,
redact_credentials,
resolve_db_url,
@@ -88,14 +87,7 @@ async def _async_query_service(
for row in result.mappings():
processed_row: dict[str, JsonValueType] = {}
for key, value in row.items():
if isinstance(value, decimal.Decimal):
processed_row[key] = float(value)
elif isinstance(value, datetime.date):
processed_row[key] = value.isoformat()
elif isinstance(value, (bytes, bytearray)):
processed_row[key] = f"0x{value.hex()}"
else:
processed_row[key] = value
processed_row[key] = convert_value(value)
rows.append(processed_row)
return rows
finally:

View File

@@ -2,7 +2,10 @@
from __future__ import annotations
from datetime import date
from decimal import Decimal
import logging
from typing import Any
import sqlalchemy
from sqlalchemy import lambda_stmt
@@ -223,3 +226,16 @@ def generate_lambda_stmt(query: str) -> StatementLambdaElement:
"""Generate the lambda statement."""
text = sqlalchemy.text(query)
return lambda_stmt(lambda: text, lambda_cache=_SQL_LAMBDA_CACHE)
def convert_value(value: Any) -> Any:
"""Convert value."""
match value:
case Decimal():
return float(value)
case date():
return value.isoformat()
case bytes() | bytearray():
return f"0x{value.hex()}"
case _:
return value

View File

@@ -1,10 +1,17 @@
"""Test the sql utils."""
from datetime import UTC, date, datetime
from decimal import Decimal
import pytest
import voluptuous as vol
from homeassistant.components.recorder import Recorder, get_instance
from homeassistant.components.sql.util import resolve_db_url, validate_sql_select
from homeassistant.components.sql.util import (
convert_value,
resolve_db_url,
validate_sql_select,
)
from homeassistant.core import HomeAssistant
@@ -64,3 +71,23 @@ async def test_invalid_sql_queries(
"""Test that various invalid or disallowed SQL queries raise the correct exception."""
with pytest.raises(vol.Invalid, match=expected_error_message):
validate_sql_select(sql_query)
@pytest.mark.parametrize(
("input", "expected_output"),
[
(Decimal("199.99"), 199.99),
(date(2023, 1, 15), "2023-01-15"),
(datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC), "2023-01-15T12:30:45+00:00"),
(b"\xde\xad\xbe\xef", "0xdeadbeef"),
("deadbeef", "deadbeef"),
(199.99, 199.99),
(69, 69),
],
)
async def test_value_conversion(
input: Decimal | date | datetime | bytes | str | float,
expected_output: str | float,
) -> None:
"""Test value conversion."""
assert convert_value(input) == expected_output