1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 12:59:34 +00:00

Fix progress step bugs (#155923)

This commit is contained in:
Erik Montnemery
2025-11-12 13:14:53 +01:00
committed by GitHub
parent eda49cced0
commit dcc559f8b6
2 changed files with 305 additions and 34 deletions

View File

@@ -645,12 +645,24 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
__progress_task: asyncio.Task[Any] | None = None
__no_progress_task_reported = False
deprecated_show_progress = False
_progress_step_data: ProgressStepData[_FlowResultT] = {
"tasks": {},
"abort_reason": "",
"abort_description_placeholders": MappingProxyType({}),
"next_step_result": None,
}
__progress_step_data: ProgressStepData[_FlowResultT] | None = None
@property
def _progress_step_data(self) -> ProgressStepData[_FlowResultT]:
"""Return progress step data.
A property is used instead of a simple attribute as derived classes
do not call super().__init__.
The property makes sure that the dict is initialized if needed.
"""
if not self.__progress_step_data:
self.__progress_step_data = {
"tasks": {},
"abort_reason": "",
"abort_description_placeholders": MappingProxyType({}),
"next_step_result": None,
}
return self.__progress_step_data
@property
def source(self) -> str | None:
@@ -777,9 +789,10 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
self, user_input: dict[str, Any] | None = None
) -> _FlowResultT:
"""Abort the flow."""
progress_step_data = self._progress_step_data
return self.async_abort(
reason=self._progress_step_data["abort_reason"],
description_placeholders=self._progress_step_data[
reason=progress_step_data["abort_reason"],
description_placeholders=progress_step_data[
"abort_description_placeholders"
],
)
@@ -795,14 +808,15 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
without using async_show_progress_done.
If no next step is set, abort the flow.
"""
if self._progress_step_data["next_step_result"] is None:
progress_step_data = self._progress_step_data
if (next_step_result := progress_step_data["next_step_result"]) is None:
return self.async_abort(
reason=self._progress_step_data["abort_reason"],
description_placeholders=self._progress_step_data[
reason=progress_step_data["abort_reason"],
description_placeholders=progress_step_data[
"abort_description_placeholders"
],
)
return self._progress_step_data["next_step_result"]
return next_step_result
@callback
def async_external_step(
@@ -1021,9 +1035,9 @@ def progress_step[
self: FlowHandler[Any, ResultT], *args: P.args, **kwargs: P.kwargs
) -> ResultT:
step_id = func.__name__.replace("async_step_", "")
progress_step_data = self._progress_step_data
# Check if we have a progress task running
progress_task = self._progress_step_data["tasks"].get(step_id)
progress_task = progress_step_data["tasks"].get(step_id)
if progress_task is None:
# First call - create and start the progress task
@@ -1031,30 +1045,30 @@ def progress_step[
func(self, *args, **kwargs), # type: ignore[arg-type]
f"Progress step {step_id}",
)
self._progress_step_data["tasks"][step_id] = progress_task
progress_step_data["tasks"][step_id] = progress_task
if not progress_task.done():
# Handle description placeholders
placeholders = None
if description_placeholders is not None:
if callable(description_placeholders):
placeholders = description_placeholders(self)
else:
placeholders = description_placeholders
if not progress_task.done():
# Handle description placeholders
placeholders = None
if description_placeholders is not None:
if callable(description_placeholders):
placeholders = description_placeholders(self)
else:
placeholders = description_placeholders
return self.async_show_progress(
step_id=step_id,
progress_action=step_id,
progress_task=progress_task,
description_placeholders=placeholders,
)
return self.async_show_progress(
step_id=step_id,
progress_action=step_id,
progress_task=progress_task,
description_placeholders=placeholders,
)
# Task is done or this is a subsequent call
try:
self._progress_step_data["next_step_result"] = await progress_task
progress_step_data["next_step_result"] = await progress_task
except AbortFlow as err:
self._progress_step_data["abort_reason"] = err.reason
self._progress_step_data["abort_description_placeholders"] = (
progress_step_data["abort_reason"] = err.reason
progress_step_data["abort_description_placeholders"] = (
err.description_placeholders or {}
)
return self.async_show_progress_done(
@@ -1062,7 +1076,7 @@ def progress_step[
)
finally:
# Clean up task reference
self._progress_step_data["tasks"].pop(step_id, None)
progress_step_data["tasks"].pop(step_id, None)
return self.async_show_progress_done(
next_step_id="_progress_step_progress_done"

View File

@@ -1,9 +1,11 @@
"""Test the flow classes."""
import asyncio
from collections.abc import Callable
import dataclasses
import logging
from unittest.mock import Mock, patch
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
import voluptuous as vol
@@ -930,6 +932,261 @@ async def test_show_progress_fires_only_when_changed(
) # change (description placeholder)
@pytest.mark.parametrize(
("task_side_effect", "flow_result"),
[
(None, data_entry_flow.FlowResultType.CREATE_ENTRY),
(data_entry_flow.AbortFlow("fail"), data_entry_flow.FlowResultType.ABORT),
],
)
@pytest.mark.parametrize(
("description", "expected_description"),
[
(None, None),
({"title": "World"}, {"title": "World"}),
(lambda x: {"title": "World"}, {"title": "World"}),
],
)
async def test_progress_step(
hass: HomeAssistant,
manager: MockFlowManager,
description: Callable[[data_entry_flow.FlowHandler], dict[str, Any]]
| dict[str, Any]
| None,
expected_description: dict[str, Any] | None,
task_side_effect: Exception | None,
flow_result: data_entry_flow.FlowResultType,
) -> None:
"""Test progress_step decorator."""
manager.hass = hass
events = []
task_init_evt = asyncio.Event()
event_received_evt = asyncio.Event()
task_result = Mock()
task_result.side_effect = task_side_effect
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
@data_entry_flow.progress_step(description_placeholders=description)
async def async_step_init(self, user_input=None):
await task_init_evt.wait()
task_result()
return await self.async_step_finish()
async def async_step_finish(self, user_input=None):
return self.async_create_entry(data={})
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
)
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "init"
description_placeholders = result["description_placeholders"]
assert description_placeholders == expected_description
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task one done and wait for event
task_init_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == 1
assert events[0].data == {
"handler": "test",
"flow_id": result["flow_id"],
"refresh": True,
}
# Frontend refreshes the flow
result = await manager.async_configure(result["flow_id"])
assert result["type"] == flow_result
@pytest.mark.parametrize(
(
"task_init_side_effect", # side effect for initial step task
"task_next_side_effect", # side effect for next step task
"flow_result_before_init", # result before init task is done
"flow_result_after_init", # result after init task is done
"flow_result_after_next", # result after next task is done
"flow_init_events", # number of events fired after init task is done
"flow_next_events", # number of events fired after next task is done
"manager_call_after_init", # lambda to continue the flow after init task
"manager_call_after_next", # lambda to continue the flow after next task
"before_init_task_side_effect", # function called before init event
"before_next_task_side_effect", # function called before next event
),
[
( # both steps show progress and complete successfully
None,
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.CREATE_ENTRY,
1,
2,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda received_event, init_task_event, next_task_event: None,
lambda received_event, init_task_event, next_task_event: None,
),
( # first step aborts
data_entry_flow.AbortFlow("fail"),
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.ABORT,
data_entry_flow.FlowResultType.ABORT,
1,
1,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: AsyncMock(return_value=result)(),
lambda received_event, init_task_event, next_task_event: None,
lambda received_event, init_task_event, next_task_event: None,
),
( # first step shows progress, second step aborts
None,
data_entry_flow.AbortFlow("fail"),
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.ABORT,
1,
2,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda received_event, init_task_event, next_task_event: None,
lambda received_event, init_task_event, next_task_event: None,
),
( # first step task is already done, second step shows progress and completes
None,
None,
data_entry_flow.FlowResultType.SHOW_PROGRESS_DONE,
data_entry_flow.FlowResultType.SHOW_PROGRESS,
data_entry_flow.FlowResultType.CREATE_ENTRY,
0,
1,
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda manager, result: manager.async_configure(result["flow_id"]),
lambda received_event,
init_task_event,
next_task_event: received_event.set() or init_task_event.set(),
lambda received_event, init_task_event, next_task_event: None,
),
],
)
async def test_chaining_progress_steps(
hass: HomeAssistant,
manager: MockFlowManager,
task_init_side_effect: Exception | None,
task_next_side_effect: Exception | None,
flow_result_before_init: data_entry_flow.FlowResultType,
flow_result_after_init: data_entry_flow.FlowResultType,
flow_result_after_next: data_entry_flow.FlowResultType,
flow_init_events: int,
flow_next_events: int,
manager_call_after_init: Callable[
[MockFlowManager, data_entry_flow.FlowResult], Any
],
manager_call_after_next: Callable[
[MockFlowManager, data_entry_flow.FlowResult], Any
],
before_init_task_side_effect: Callable[
[asyncio.Event, asyncio.Event, asyncio.Event], None
],
before_next_task_side_effect: Callable[
[asyncio.Event, asyncio.Event, asyncio.Event], None
],
) -> None:
"""Test chaining two steps with progress_step decorators."""
manager.hass = hass
events = []
event_received_evt = asyncio.Event()
task_init_evt = asyncio.Event()
task_next_evt = asyncio.Event()
task_init_result = Mock()
task_init_result.side_effect = task_init_side_effect
task_next_result = Mock()
task_next_result.side_effect = task_next_side_effect
@callback
def capture_events(event: Event) -> None:
events.append(event)
event_received_evt.set()
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
def async_remove(self) -> None:
# Disable event received event to allow test to finish if flow is aborted.
event_received_evt.set()
@data_entry_flow.progress_step()
async def async_step_init(self, user_input=None):
await task_init_evt.wait()
task_init_result()
return await self.async_step_next()
@data_entry_flow.progress_step()
async def async_step_next(self, user_input=None):
await task_next_evt.wait()
task_next_result()
return await self.async_step_finish()
async def async_step_finish(self, user_input=None):
return self.async_create_entry(data={})
hass.bus.async_listen(
data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED,
capture_events,
)
# Run side effect before first event is awaited
before_init_task_side_effect(event_received_evt, task_init_evt, task_next_evt)
result = await manager.async_init("test")
assert result["type"] == flow_result_before_init
assert len(manager.async_progress()) == 1
assert len(manager.async_progress_by_handler("test")) == 1
assert manager.async_get(result["flow_id"])["handler"] == "test"
# Set task init done and wait for event
task_init_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == flow_init_events
# Run side effect before second event is awaited
before_next_task_side_effect(event_received_evt, task_init_evt, task_next_evt)
# Continue the flow if needed.
result = await manager_call_after_init(manager, result)
assert result["type"] == flow_result_after_init
# Set task next done and wait for event
task_next_evt.set()
await event_received_evt.wait()
event_received_evt.clear()
assert len(events) == flow_next_events
# Continue the flow if needed.
result = await manager_call_after_next(manager, result)
assert result["type"] == flow_result_after_next
async def test_abort_flow_exception_step(manager: MockFlowManager) -> None:
"""Test that the AbortFlow exception works in a step."""