mirror of
https://github.com/home-assistant/core.git
synced 2026-03-01 14:25:31 +00:00
Refactor SQL's data conversion (#155598)
This commit is contained in:
@@ -2,8 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
import decimal
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -43,6 +41,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .util import (
|
||||
async_create_sessionmaker,
|
||||
convert_value,
|
||||
generate_lambda_stmt,
|
||||
redact_credentials,
|
||||
resolve_db_url,
|
||||
@@ -253,7 +252,6 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
def _update(self) -> None:
|
||||
"""Retrieve sensor data from the query."""
|
||||
data = None
|
||||
extra_state_attributes = {}
|
||||
self._attr_extra_state_attributes = {}
|
||||
sess: scoped_session = self.sessionmaker()
|
||||
try:
|
||||
@@ -272,14 +270,7 @@ class SQLSensor(ManualTriggerSensorEntity):
|
||||
_LOGGER.debug("Query %s result in %s", self._query, res.items())
|
||||
data = res[self._column_name]
|
||||
for key, value in res.items():
|
||||
if isinstance(value, decimal.Decimal):
|
||||
value = float(value)
|
||||
elif isinstance(value, date):
|
||||
value = value.isoformat()
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
value = f"0x{value.hex()}"
|
||||
extra_state_attributes[key] = value
|
||||
self._attr_extra_state_attributes[key] = value
|
||||
self._attr_extra_state_attributes[key] = convert_value(value)
|
||||
|
||||
if data is not None and isinstance(data, (bytes, bytearray)):
|
||||
data = f"0x{data.hex()}"
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
import logging
|
||||
|
||||
from sqlalchemy.engine import Result
|
||||
@@ -26,6 +24,7 @@ from homeassistant.util.json import JsonValueType
|
||||
from .const import CONF_QUERY, DOMAIN
|
||||
from .util import (
|
||||
async_create_sessionmaker,
|
||||
convert_value,
|
||||
generate_lambda_stmt,
|
||||
redact_credentials,
|
||||
resolve_db_url,
|
||||
@@ -88,14 +87,7 @@ async def _async_query_service(
|
||||
for row in result.mappings():
|
||||
processed_row: dict[str, JsonValueType] = {}
|
||||
for key, value in row.items():
|
||||
if isinstance(value, decimal.Decimal):
|
||||
processed_row[key] = float(value)
|
||||
elif isinstance(value, datetime.date):
|
||||
processed_row[key] = value.isoformat()
|
||||
elif isinstance(value, (bytes, bytearray)):
|
||||
processed_row[key] = f"0x{value.hex()}"
|
||||
else:
|
||||
processed_row[key] = value
|
||||
processed_row[key] = convert_value(value)
|
||||
rows.append(processed_row)
|
||||
return rows
|
||||
finally:
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import lambda_stmt
|
||||
@@ -223,3 +226,16 @@ def generate_lambda_stmt(query: str) -> StatementLambdaElement:
|
||||
"""Generate the lambda statement."""
|
||||
text = sqlalchemy.text(query)
|
||||
return lambda_stmt(lambda: text, lambda_cache=_SQL_LAMBDA_CACHE)
|
||||
|
||||
|
||||
def convert_value(value: Any) -> Any:
|
||||
"""Convert value."""
|
||||
match value:
|
||||
case Decimal():
|
||||
return float(value)
|
||||
case date():
|
||||
return value.isoformat()
|
||||
case bytes() | bytearray():
|
||||
return f"0x{value.hex()}"
|
||||
case _:
|
||||
return value
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
"""Test the sql utils."""
|
||||
|
||||
from datetime import UTC, date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.recorder import Recorder, get_instance
|
||||
from homeassistant.components.sql.util import resolve_db_url, validate_sql_select
|
||||
from homeassistant.components.sql.util import (
|
||||
convert_value,
|
||||
resolve_db_url,
|
||||
validate_sql_select,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
@@ -64,3 +71,23 @@ async def test_invalid_sql_queries(
|
||||
"""Test that various invalid or disallowed SQL queries raise the correct exception."""
|
||||
with pytest.raises(vol.Invalid, match=expected_error_message):
|
||||
validate_sql_select(sql_query)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("input", "expected_output"),
|
||||
[
|
||||
(Decimal("199.99"), 199.99),
|
||||
(date(2023, 1, 15), "2023-01-15"),
|
||||
(datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC), "2023-01-15T12:30:45+00:00"),
|
||||
(b"\xde\xad\xbe\xef", "0xdeadbeef"),
|
||||
("deadbeef", "deadbeef"),
|
||||
(199.99, 199.99),
|
||||
(69, 69),
|
||||
],
|
||||
)
|
||||
async def test_value_conversion(
|
||||
input: Decimal | date | datetime | bytes | str | float,
|
||||
expected_output: str | float,
|
||||
) -> None:
|
||||
"""Test value conversion."""
|
||||
assert convert_value(input) == expected_output
|
||||
|
||||
Reference in New Issue
Block a user