mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 12:59:34 +00:00
Prevent recursive script calls from deadlocking (#67861)
* Prevent recursive script calls from deadlocking * Address review comments, improve tests * Tweak comment
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import itertools
|
||||
@@ -126,6 +127,8 @@ SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
|
||||
SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
|
||||
SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
|
||||
|
||||
script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
|
||||
|
||||
|
||||
def action_trace_append(variables, path):
|
||||
"""Append a TraceElement to trace[path]."""
|
||||
@@ -340,6 +343,12 @@ class _ScriptRun:
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
# Push the script to the script execution stack
|
||||
if (script_stack := script_stack_cv.get()) is None:
|
||||
script_stack = []
|
||||
script_stack_cv.set(script_stack)
|
||||
script_stack.append(id(self._script))
|
||||
|
||||
try:
|
||||
self._log("Running %s", self._script.running_description)
|
||||
for self._step, self._action in enumerate(self._script.sequence):
|
||||
@@ -355,6 +364,8 @@ class _ScriptRun:
|
||||
script_execution_set("error")
|
||||
raise
|
||||
finally:
|
||||
# Pop the script from the script execution stack
|
||||
script_stack.pop()
|
||||
self._finish()
|
||||
|
||||
async def _async_step(self, log_exceptions):
|
||||
@@ -1218,6 +1229,18 @@ class Script:
|
||||
else:
|
||||
variables = cast(dict, run_variables)
|
||||
|
||||
# Prevent non-allowed recursive calls which will cause deadlocks when we try to
|
||||
# stop (restart) or wait for (queued) our own script run.
|
||||
script_stack = script_stack_cv.get()
|
||||
if (
|
||||
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
|
||||
and (script_stack := script_stack_cv.get()) is not None
|
||||
and id(self) in script_stack
|
||||
):
|
||||
script_execution_set("disallowed_recursion_detected")
|
||||
_LOGGER.warning("Disallowed recursion detected")
|
||||
return
|
||||
|
||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||
cls = _ScriptRun
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user