From 144fc2a443be18740732cd3f4ac9afaceade1e52 Mon Sep 17 00:00:00 2001 From: David Rapan Date: Sun, 2 Nov 2025 18:49:18 +0100 Subject: [PATCH] Refactor SQL's data conversion (#155598) --- homeassistant/components/sql/sensor.py | 13 ++--------- homeassistant/components/sql/services.py | 12 ++-------- homeassistant/components/sql/util.py | 16 +++++++++++++ tests/components/sql/test_util.py | 29 +++++++++++++++++++++++- 4 files changed, 48 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 508365b5c0d..c8885cd2377 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -2,8 +2,6 @@ from __future__ import annotations -from datetime import date -import decimal import logging from typing import Any @@ -43,6 +41,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN from .util import ( async_create_sessionmaker, + convert_value, generate_lambda_stmt, redact_credentials, resolve_db_url, @@ -253,7 +252,6 @@ class SQLSensor(ManualTriggerSensorEntity): def _update(self) -> None: """Retrieve sensor data from the query.""" data = None - extra_state_attributes = {} self._attr_extra_state_attributes = {} sess: scoped_session = self.sessionmaker() try: @@ -272,14 +270,7 @@ class SQLSensor(ManualTriggerSensorEntity): _LOGGER.debug("Query %s result in %s", self._query, res.items()) data = res[self._column_name] for key, value in res.items(): - if isinstance(value, decimal.Decimal): - value = float(value) - elif isinstance(value, date): - value = value.isoformat() - elif isinstance(value, (bytes, bytearray)): - value = f"0x{value.hex()}" - extra_state_attributes[key] = value - self._attr_extra_state_attributes[key] = value + self._attr_extra_state_attributes[key] = convert_value(value) if data is not None and isinstance(data, (bytes, bytearray)): data = f"0x{data.hex()}" diff --git a/homeassistant/components/sql/services.py b/homeassistant/components/sql/services.py index c7b74bd82b6..dc31064d3ec 100644 --- a/homeassistant/components/sql/services.py +++ b/homeassistant/components/sql/services.py @@ -2,8 +2,6 @@ from __future__ import annotations -import datetime -import decimal import logging from sqlalchemy.engine import Result @@ -26,6 +24,7 @@ from homeassistant.util.json import JsonValueType from .const import CONF_QUERY, DOMAIN from .util import ( async_create_sessionmaker, + convert_value, generate_lambda_stmt, redact_credentials, resolve_db_url, @@ -88,14 +87,7 @@ async def _async_query_service( for row in result.mappings(): processed_row: dict[str, JsonValueType] = {} for key, value in row.items(): - if isinstance(value, decimal.Decimal): - processed_row[key] = float(value) - elif isinstance(value, datetime.date): - processed_row[key] = value.isoformat() - elif isinstance(value, (bytes, bytearray)): - processed_row[key] = f"0x{value.hex()}" - else: - processed_row[key] = value + processed_row[key] = convert_value(value) rows.append(processed_row) return rows finally: diff --git a/homeassistant/components/sql/util.py b/homeassistant/components/sql/util.py index 0200a83c9e8..cc6f1bb5ea1 100644 --- a/homeassistant/components/sql/util.py +++ b/homeassistant/components/sql/util.py @@ -2,7 +2,10 @@ from __future__ import annotations +from datetime import date +from decimal import Decimal import logging +from typing import Any import sqlalchemy from sqlalchemy import lambda_stmt @@ -223,3 +226,16 @@ def generate_lambda_stmt(query: str) -> StatementLambdaElement: """Generate the lambda statement.""" text = sqlalchemy.text(query) return lambda_stmt(lambda: text, lambda_cache=_SQL_LAMBDA_CACHE) + + +def convert_value(value: Any) -> Any: + """Convert value.""" + match value: + case Decimal(): + return float(value) + case date(): + return value.isoformat() + case bytes() | bytearray(): + return f"0x{value.hex()}" + case _: + return value diff --git a/tests/components/sql/test_util.py b/tests/components/sql/test_util.py index 737a5e4a41b..7023fb17cc2 100644 --- a/tests/components/sql/test_util.py +++ b/tests/components/sql/test_util.py @@ -1,10 +1,17 @@ """Test the sql utils.""" +from datetime import UTC, date, datetime +from decimal import Decimal + import pytest import voluptuous as vol from homeassistant.components.recorder import Recorder, get_instance -from homeassistant.components.sql.util import resolve_db_url, validate_sql_select +from homeassistant.components.sql.util import ( + convert_value, + resolve_db_url, + validate_sql_select, +) from homeassistant.core import HomeAssistant @@ -64,3 +71,23 @@ async def test_invalid_sql_queries( """Test that various invalid or disallowed SQL queries raise the correct exception.""" with pytest.raises(vol.Invalid, match=expected_error_message): validate_sql_select(sql_query) + + +@pytest.mark.parametrize( + ("input", "expected_output"), + [ + (Decimal("199.99"), 199.99), + (date(2023, 1, 15), "2023-01-15"), + (datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC), "2023-01-15T12:30:45+00:00"), + (b"\xde\xad\xbe\xef", "0xdeadbeef"), + ("deadbeef", "deadbeef"), + (199.99, 199.99), + (69, 69), + ], +) +async def test_value_conversion( + input: Decimal | date | datetime | bytes | str | float, + expected_output: str | float, +) -> None: + """Test value conversion.""" + assert convert_value(input) == expected_output