mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 12:59:34 +00:00
Add recorder test fixture to enable persistent SQLite database (#121137)
* Add recorder test fixture to enable persistent SQLite database * Fix tests directly using async_test_home_assistant context manager
This commit is contained in:
@@ -12,6 +12,7 @@ import itertools
|
||||
import logging
|
||||
import os
|
||||
import reprlib
|
||||
from shutil import rmtree
|
||||
import sqlite3
|
||||
import ssl
|
||||
import threading
|
||||
@@ -1309,16 +1310,36 @@ def recorder_config() -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def persistent_database() -> bool:
|
||||
"""Fixture to control if database should persist when recorder is shut down in test.
|
||||
|
||||
When using sqlite, this uses on disk database instead of in memory database.
|
||||
This does nothing when using mysql or postgresql.
|
||||
|
||||
Note that the database is always destroyed in between tests.
|
||||
|
||||
To use a persistent database, tests can be marked with:
|
||||
@pytest.mark.parametrize("persistent_database", [True])
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recorder_db_url(
|
||||
pytestconfig: pytest.Config,
|
||||
hass_fixture_setup: list[bool],
|
||||
persistent_database: str,
|
||||
tmp_path_factory: pytest.TempPathFactory,
|
||||
) -> Generator[str]:
|
||||
"""Prepare a default database for tests and return a connection URL."""
|
||||
assert not hass_fixture_setup
|
||||
|
||||
db_url = cast(str, pytestconfig.getoption("dburl"))
|
||||
if db_url.startswith("mysql://"):
|
||||
if db_url == "sqlite://" and persistent_database:
|
||||
tmp_path = tmp_path_factory.mktemp("recorder")
|
||||
db_url = "sqlite:///" + str(tmp_path / "pytest.db")
|
||||
elif db_url.startswith("mysql://"):
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import sqlalchemy_utils
|
||||
|
||||
@@ -1332,7 +1353,9 @@ def recorder_db_url(
|
||||
assert not sqlalchemy_utils.database_exists(db_url)
|
||||
sqlalchemy_utils.create_database(db_url, encoding="utf8")
|
||||
yield db_url
|
||||
if db_url.startswith("mysql://"):
|
||||
if db_url == "sqlite://" and persistent_database:
|
||||
rmtree(tmp_path, ignore_errors=True)
|
||||
elif db_url.startswith("mysql://"):
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
Reference in New Issue
Block a user