1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 12:59:34 +00:00

Add support for importing integrations in the executor (#111336)

* Add support for pre-imports at setup time

alternative solution to #111331

* refactor

* refactor

* refactor

* mark >1.0s integrations

* no point in executor if already loaded

* no point in executor if already loaded

* cleanup

* cleanup

* two more

* one more

* analytics loads a lot more integrations

* cloud

* debug

* psutil, hardwre

* try zha

* Update homeassistant/setup.py

* await

* comments

* coverage

* coverage

* coverage

* move logic to loader

* move logic to loader

* preserve comments
This commit is contained in:
J. Nick Koston
2024-02-26 09:49:43 -10:00
committed by GitHub
parent 75e59167de
commit 4ea1c5cc3c
25 changed files with 278 additions and 9 deletions

View File

@@ -14,6 +14,7 @@ import importlib
import logging
import pathlib
import sys
import time
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, cast
@@ -179,6 +180,7 @@ class Manifest(TypedDict, total=False):
version: str
codeowners: list[str]
loggers: list[str]
import_executor: bool
single_config_entry: bool
@@ -658,6 +660,7 @@ class Integration:
self._all_dependencies_resolved = True
self._all_dependencies = set()
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
@cached_property
@@ -727,6 +730,11 @@ class Integration:
"""Return the integration type."""
return self.manifest.get("integration_type", "hub")
@cached_property
def import_executor(self) -> bool:
"""Import integration in the executor."""
return self.manifest.get("import_executor") or False
@property
def mqtt(self) -> list[str] | None:
"""Return Integration MQTT entries."""
@@ -826,8 +834,47 @@ class Integration:
return self._all_dependencies_resolved
async def async_get_component(self) -> ComponentProtocol:
"""Return the component.
This method will load the component if it't not already loaded
and will check if import_executor is set and load it in the executor,
otherwise it will load it in the event loop.
"""
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
start = time.perf_counter()
domain = self.domain
load_executor = (
self.import_executor
and f"hass.components.{domain}" not in sys.modules
and f"custom_components.{domain}" not in sys.modules
)
# Some integrations fail on import because they call functions incorrectly.
# So we do it before validating config to catch these errors.
if load_executor:
comp = await self.hass.async_add_executor_job(self.get_component)
else:
comp = self.get_component()
if debug:
_LOGGER.debug(
"Component %s import took %.3f seconds (loaded_executor=%s)",
domain,
time.perf_counter() - start,
load_executor,
)
return comp
def get_component(self) -> ComponentProtocol:
"""Return the component."""
"""Return the component.
This method must be thread-safe as its called from the executor
and the event loop.
This is mostly a thin wrapper around importlib.import_module
with a dict cache which is thread-safe since importlib has
appropriate locks.
"""
cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
if self.domain in cache:
return cache[self.domain]
@@ -846,10 +893,56 @@ class Integration:
return cache[self.domain]
def get_platform(self, platform_name: str) -> ModuleType:
async def async_get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration."""
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
domain = self.domain
full_name = f"{self.domain}.{platform_name}"
if platform := self._get_platform_cached(full_name):
return platform
if future := self._import_futures.get(full_name):
return await future
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
start = time.perf_counter()
import_future = self.hass.loop.create_future()
self._import_futures[full_name] = import_future
load_executor = (
self.import_executor
and domain not in self.hass.config.components
and f"hass.components.{domain}" not in sys.modules
and f"custom_components.{domain}" not in sys.modules
)
try:
if load_executor:
platform = await self.hass.async_add_executor_job(
self._load_platform, platform_name
)
else:
platform = self._load_platform(platform_name)
import_future.set_result(platform)
except BaseException as ex:
import_future.set_exception(ex)
with suppress(BaseException):
# Clear the exception retrieved flag on the future since
# it will never be retrieved unless there
# are concurrent calls to async_get_platform
import_future.result()
raise
finally:
self._import_futures.pop(full_name)
if debug:
_LOGGER.debug(
"Loaded flow for %s in %.2fs (loaded_executor=%s)",
domain,
time.perf_counter() - start,
load_executor,
)
return platform
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
"""Return a platform for an integration from cache."""
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
if full_name in cache:
return cache[full_name]
@@ -859,12 +952,35 @@ class Integration:
if full_name in missing_platforms_cache:
raise missing_platforms_cache[full_name]
return None
def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration."""
if platform := self._get_platform_cached(f"{self.domain}.{platform_name}"):
return platform
return self._load_platform(platform_name)
def _load_platform(self, platform_name: str) -> ModuleType:
"""Load a platform for an integration.
This method must be thread-safe as its called from the executor
and the event loop.
This is mostly a thin wrapper around importlib.import_module
with a dict cache which is thread-safe since importlib has
appropriate locks.
"""
full_name = f"{self.domain}.{platform_name}"
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
try:
cache[full_name] = self._import_platform(platform_name)
except ImportError as ex:
if self.domain in cache:
# If the domain is loaded, cache that the platform
# does not exist so we do not try to load it again
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
DATA_MISSING_PLATFORMS
]
missing_platforms_cache[full_name] = ex
raise
except Exception as err:
@@ -880,7 +996,11 @@ class Integration:
return cache[full_name]
def _import_platform(self, platform_name: str) -> ModuleType:
"""Import the platform."""
"""Import the platform.
This method must be thread-safe as its called from the executor
and the event loop.
"""
return importlib.import_module(f"{self.pkg_path}.{platform_name}")
def __repr__(self) -> str: