mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 12:59:34 +00:00
Add support for importing integrations in the executor (#111336)
* Add support for pre-imports at setup time alternative solution to #111331 * refactor * refactor * refactor * mark >1.0s integrations * no point in executor if already loaded * no point in executor if already loaded * cleanup * cleanup * two more * one more * analytics loads a lot more integrations * cloud * debug * psutil, hardwre * try zha * Update homeassistant/setup.py * await * comments * coverage * coverage * coverage * move logic to loader * move logic to loader * preserve comments
This commit is contained in:
@@ -14,6 +14,7 @@ import importlib
|
||||
import logging
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, cast
|
||||
|
||||
@@ -179,6 +180,7 @@ class Manifest(TypedDict, total=False):
|
||||
version: str
|
||||
codeowners: list[str]
|
||||
loggers: list[str]
|
||||
import_executor: bool
|
||||
single_config_entry: bool
|
||||
|
||||
|
||||
@@ -658,6 +660,7 @@ class Integration:
|
||||
self._all_dependencies_resolved = True
|
||||
self._all_dependencies = set()
|
||||
|
||||
self._import_futures: dict[str, asyncio.Future[ModuleType]] = {}
|
||||
_LOGGER.info("Loaded %s from %s", self.domain, pkg_path)
|
||||
|
||||
@cached_property
|
||||
@@ -727,6 +730,11 @@ class Integration:
|
||||
"""Return the integration type."""
|
||||
return self.manifest.get("integration_type", "hub")
|
||||
|
||||
@cached_property
|
||||
def import_executor(self) -> bool:
|
||||
"""Import integration in the executor."""
|
||||
return self.manifest.get("import_executor") or False
|
||||
|
||||
@property
|
||||
def mqtt(self) -> list[str] | None:
|
||||
"""Return Integration MQTT entries."""
|
||||
@@ -826,8 +834,47 @@ class Integration:
|
||||
|
||||
return self._all_dependencies_resolved
|
||||
|
||||
async def async_get_component(self) -> ComponentProtocol:
|
||||
"""Return the component.
|
||||
|
||||
This method will load the component if it't not already loaded
|
||||
and will check if import_executor is set and load it in the executor,
|
||||
otherwise it will load it in the event loop.
|
||||
"""
|
||||
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
start = time.perf_counter()
|
||||
domain = self.domain
|
||||
load_executor = (
|
||||
self.import_executor
|
||||
and f"hass.components.{domain}" not in sys.modules
|
||||
and f"custom_components.{domain}" not in sys.modules
|
||||
)
|
||||
# Some integrations fail on import because they call functions incorrectly.
|
||||
# So we do it before validating config to catch these errors.
|
||||
if load_executor:
|
||||
comp = await self.hass.async_add_executor_job(self.get_component)
|
||||
else:
|
||||
comp = self.get_component()
|
||||
|
||||
if debug:
|
||||
_LOGGER.debug(
|
||||
"Component %s import took %.3f seconds (loaded_executor=%s)",
|
||||
domain,
|
||||
time.perf_counter() - start,
|
||||
load_executor,
|
||||
)
|
||||
return comp
|
||||
|
||||
def get_component(self) -> ComponentProtocol:
|
||||
"""Return the component."""
|
||||
"""Return the component.
|
||||
|
||||
This method must be thread-safe as its called from the executor
|
||||
and the event loop.
|
||||
|
||||
This is mostly a thin wrapper around importlib.import_module
|
||||
with a dict cache which is thread-safe since importlib has
|
||||
appropriate locks.
|
||||
"""
|
||||
cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
|
||||
if self.domain in cache:
|
||||
return cache[self.domain]
|
||||
@@ -846,10 +893,56 @@ class Integration:
|
||||
|
||||
return cache[self.domain]
|
||||
|
||||
def get_platform(self, platform_name: str) -> ModuleType:
|
||||
async def async_get_platform(self, platform_name: str) -> ModuleType:
|
||||
"""Return a platform for an integration."""
|
||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
||||
domain = self.domain
|
||||
full_name = f"{self.domain}.{platform_name}"
|
||||
if platform := self._get_platform_cached(full_name):
|
||||
return platform
|
||||
if future := self._import_futures.get(full_name):
|
||||
return await future
|
||||
if debug := _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
start = time.perf_counter()
|
||||
import_future = self.hass.loop.create_future()
|
||||
self._import_futures[full_name] = import_future
|
||||
load_executor = (
|
||||
self.import_executor
|
||||
and domain not in self.hass.config.components
|
||||
and f"hass.components.{domain}" not in sys.modules
|
||||
and f"custom_components.{domain}" not in sys.modules
|
||||
)
|
||||
try:
|
||||
if load_executor:
|
||||
platform = await self.hass.async_add_executor_job(
|
||||
self._load_platform, platform_name
|
||||
)
|
||||
else:
|
||||
platform = self._load_platform(platform_name)
|
||||
import_future.set_result(platform)
|
||||
except BaseException as ex:
|
||||
import_future.set_exception(ex)
|
||||
with suppress(BaseException):
|
||||
# Clear the exception retrieved flag on the future since
|
||||
# it will never be retrieved unless there
|
||||
# are concurrent calls to async_get_platform
|
||||
import_future.result()
|
||||
raise
|
||||
finally:
|
||||
self._import_futures.pop(full_name)
|
||||
|
||||
if debug:
|
||||
_LOGGER.debug(
|
||||
"Loaded flow for %s in %.2fs (loaded_executor=%s)",
|
||||
domain,
|
||||
time.perf_counter() - start,
|
||||
load_executor,
|
||||
)
|
||||
|
||||
return platform
|
||||
|
||||
def _get_platform_cached(self, full_name: str) -> ModuleType | None:
|
||||
"""Return a platform for an integration from cache."""
|
||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
||||
if full_name in cache:
|
||||
return cache[full_name]
|
||||
|
||||
@@ -859,12 +952,35 @@ class Integration:
|
||||
if full_name in missing_platforms_cache:
|
||||
raise missing_platforms_cache[full_name]
|
||||
|
||||
return None
|
||||
|
||||
def get_platform(self, platform_name: str) -> ModuleType:
|
||||
"""Return a platform for an integration."""
|
||||
if platform := self._get_platform_cached(f"{self.domain}.{platform_name}"):
|
||||
return platform
|
||||
return self._load_platform(platform_name)
|
||||
|
||||
def _load_platform(self, platform_name: str) -> ModuleType:
|
||||
"""Load a platform for an integration.
|
||||
|
||||
This method must be thread-safe as its called from the executor
|
||||
and the event loop.
|
||||
|
||||
This is mostly a thin wrapper around importlib.import_module
|
||||
with a dict cache which is thread-safe since importlib has
|
||||
appropriate locks.
|
||||
"""
|
||||
full_name = f"{self.domain}.{platform_name}"
|
||||
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
|
||||
try:
|
||||
cache[full_name] = self._import_platform(platform_name)
|
||||
except ImportError as ex:
|
||||
if self.domain in cache:
|
||||
# If the domain is loaded, cache that the platform
|
||||
# does not exist so we do not try to load it again
|
||||
missing_platforms_cache: dict[str, ImportError] = self.hass.data[
|
||||
DATA_MISSING_PLATFORMS
|
||||
]
|
||||
missing_platforms_cache[full_name] = ex
|
||||
raise
|
||||
except Exception as err:
|
||||
@@ -880,7 +996,11 @@ class Integration:
|
||||
return cache[full_name]
|
||||
|
||||
def _import_platform(self, platform_name: str) -> ModuleType:
|
||||
"""Import the platform."""
|
||||
"""Import the platform.
|
||||
|
||||
This method must be thread-safe as its called from the executor
|
||||
and the event loop.
|
||||
"""
|
||||
return importlib.import_module(f"{self.pkg_path}.{platform_name}")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user