diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 51dcd3fa13e..beb86c1fd46 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -645,12 +645,24 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]): __progress_task: asyncio.Task[Any] | None = None __no_progress_task_reported = False deprecated_show_progress = False - _progress_step_data: ProgressStepData[_FlowResultT] = { - "tasks": {}, - "abort_reason": "", - "abort_description_placeholders": MappingProxyType({}), - "next_step_result": None, - } + __progress_step_data: ProgressStepData[_FlowResultT] | None = None + + @property + def _progress_step_data(self) -> ProgressStepData[_FlowResultT]: + """Return progress step data. + + A property is used instead of a simple attribute as derived classes + do not call super().__init__. + The property makes sure that the dict is initialized if needed. + """ + if not self.__progress_step_data: + self.__progress_step_data = { + "tasks": {}, + "abort_reason": "", + "abort_description_placeholders": MappingProxyType({}), + "next_step_result": None, + } + return self.__progress_step_data @property def source(self) -> str | None: @@ -777,9 +789,10 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]): self, user_input: dict[str, Any] | None = None ) -> _FlowResultT: """Abort the flow.""" + progress_step_data = self._progress_step_data return self.async_abort( - reason=self._progress_step_data["abort_reason"], - description_placeholders=self._progress_step_data[ + reason=progress_step_data["abort_reason"], + description_placeholders=progress_step_data[ "abort_description_placeholders" ], ) @@ -795,14 +808,15 @@ class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]): without using async_show_progress_done. If no next step is set, abort the flow. """ - if self._progress_step_data["next_step_result"] is None: + progress_step_data = self._progress_step_data + if (next_step_result := progress_step_data["next_step_result"]) is None: return self.async_abort( - reason=self._progress_step_data["abort_reason"], - description_placeholders=self._progress_step_data[ + reason=progress_step_data["abort_reason"], + description_placeholders=progress_step_data[ "abort_description_placeholders" ], ) - return self._progress_step_data["next_step_result"] + return next_step_result @callback def async_external_step( @@ -1021,9 +1035,9 @@ def progress_step[ self: FlowHandler[Any, ResultT], *args: P.args, **kwargs: P.kwargs ) -> ResultT: step_id = func.__name__.replace("async_step_", "") - + progress_step_data = self._progress_step_data # Check if we have a progress task running - progress_task = self._progress_step_data["tasks"].get(step_id) + progress_task = progress_step_data["tasks"].get(step_id) if progress_task is None: # First call - create and start the progress task @@ -1031,30 +1045,30 @@ def progress_step[ func(self, *args, **kwargs), # type: ignore[arg-type] f"Progress step {step_id}", ) - self._progress_step_data["tasks"][step_id] = progress_task + progress_step_data["tasks"][step_id] = progress_task - if not progress_task.done(): - # Handle description placeholders - placeholders = None - if description_placeholders is not None: - if callable(description_placeholders): - placeholders = description_placeholders(self) - else: - placeholders = description_placeholders + if not progress_task.done(): + # Handle description placeholders + placeholders = None + if description_placeholders is not None: + if callable(description_placeholders): + placeholders = description_placeholders(self) + else: + placeholders = description_placeholders - return self.async_show_progress( - step_id=step_id, - progress_action=step_id, - progress_task=progress_task, - description_placeholders=placeholders, - ) + return self.async_show_progress( + step_id=step_id, + progress_action=step_id, + progress_task=progress_task, + description_placeholders=placeholders, + ) # Task is done or this is a subsequent call try: - self._progress_step_data["next_step_result"] = await progress_task + progress_step_data["next_step_result"] = await progress_task except AbortFlow as err: - self._progress_step_data["abort_reason"] = err.reason - self._progress_step_data["abort_description_placeholders"] = ( + progress_step_data["abort_reason"] = err.reason + progress_step_data["abort_description_placeholders"] = ( err.description_placeholders or {} ) return self.async_show_progress_done( @@ -1062,7 +1076,7 @@ def progress_step[ ) finally: # Clean up task reference - self._progress_step_data["tasks"].pop(step_id, None) + progress_step_data["tasks"].pop(step_id, None) return self.async_show_progress_done( next_step_id="_progress_step_progress_done" diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 55ff79e2531..b6dc2b39c7c 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -1,9 +1,11 @@ """Test the flow classes.""" import asyncio +from collections.abc import Callable import dataclasses import logging -from unittest.mock import Mock, patch +from typing import Any +from unittest.mock import AsyncMock, Mock, patch import pytest import voluptuous as vol @@ -930,6 +932,261 @@ async def test_show_progress_fires_only_when_changed( ) # change (description placeholder) +@pytest.mark.parametrize( + ("task_side_effect", "flow_result"), + [ + (None, data_entry_flow.FlowResultType.CREATE_ENTRY), + (data_entry_flow.AbortFlow("fail"), data_entry_flow.FlowResultType.ABORT), + ], +) +@pytest.mark.parametrize( + ("description", "expected_description"), + [ + (None, None), + ({"title": "World"}, {"title": "World"}), + (lambda x: {"title": "World"}, {"title": "World"}), + ], +) +async def test_progress_step( + hass: HomeAssistant, + manager: MockFlowManager, + description: Callable[[data_entry_flow.FlowHandler], dict[str, Any]] + | dict[str, Any] + | None, + expected_description: dict[str, Any] | None, + task_side_effect: Exception | None, + flow_result: data_entry_flow.FlowResultType, +) -> None: + """Test progress_step decorator.""" + manager.hass = hass + events = [] + task_init_evt = asyncio.Event() + event_received_evt = asyncio.Event() + task_result = Mock() + task_result.side_effect = task_side_effect + + @callback + def capture_events(event: Event) -> None: + events.append(event) + event_received_evt.set() + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + + @data_entry_flow.progress_step(description_placeholders=description) + async def async_step_init(self, user_input=None): + await task_init_evt.wait() + task_result() + + return await self.async_step_finish() + + async def async_step_finish(self, user_input=None): + return self.async_create_entry(data={}) + + hass.bus.async_listen( + data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED, + capture_events, + ) + + result = await manager.async_init("test") + assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS + assert result["progress_action"] == "init" + description_placeholders = result["description_placeholders"] + assert description_placeholders == expected_description + assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" + + # Set task one done and wait for event + task_init_evt.set() + await event_received_evt.wait() + event_received_evt.clear() + assert len(events) == 1 + assert events[0].data == { + "handler": "test", + "flow_id": result["flow_id"], + "refresh": True, + } + + # Frontend refreshes the flow + result = await manager.async_configure(result["flow_id"]) + assert result["type"] == flow_result + + +@pytest.mark.parametrize( + ( + "task_init_side_effect", # side effect for initial step task + "task_next_side_effect", # side effect for next step task + "flow_result_before_init", # result before init task is done + "flow_result_after_init", # result after init task is done + "flow_result_after_next", # result after next task is done + "flow_init_events", # number of events fired after init task is done + "flow_next_events", # number of events fired after next task is done + "manager_call_after_init", # lambda to continue the flow after init task + "manager_call_after_next", # lambda to continue the flow after next task + "before_init_task_side_effect", # function called before init event + "before_next_task_side_effect", # function called before next event + ), + [ + ( # both steps show progress and complete successfully + None, + None, + data_entry_flow.FlowResultType.SHOW_PROGRESS, + data_entry_flow.FlowResultType.SHOW_PROGRESS, + data_entry_flow.FlowResultType.CREATE_ENTRY, + 1, + 2, + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda received_event, init_task_event, next_task_event: None, + lambda received_event, init_task_event, next_task_event: None, + ), + ( # first step aborts + data_entry_flow.AbortFlow("fail"), + None, + data_entry_flow.FlowResultType.SHOW_PROGRESS, + data_entry_flow.FlowResultType.ABORT, + data_entry_flow.FlowResultType.ABORT, + 1, + 1, + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda manager, result: AsyncMock(return_value=result)(), + lambda received_event, init_task_event, next_task_event: None, + lambda received_event, init_task_event, next_task_event: None, + ), + ( # first step shows progress, second step aborts + None, + data_entry_flow.AbortFlow("fail"), + data_entry_flow.FlowResultType.SHOW_PROGRESS, + data_entry_flow.FlowResultType.SHOW_PROGRESS, + data_entry_flow.FlowResultType.ABORT, + 1, + 2, + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda received_event, init_task_event, next_task_event: None, + lambda received_event, init_task_event, next_task_event: None, + ), + ( # first step task is already done, second step shows progress and completes + None, + None, + data_entry_flow.FlowResultType.SHOW_PROGRESS_DONE, + data_entry_flow.FlowResultType.SHOW_PROGRESS, + data_entry_flow.FlowResultType.CREATE_ENTRY, + 0, + 1, + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda manager, result: manager.async_configure(result["flow_id"]), + lambda received_event, + init_task_event, + next_task_event: received_event.set() or init_task_event.set(), + lambda received_event, init_task_event, next_task_event: None, + ), + ], +) +async def test_chaining_progress_steps( + hass: HomeAssistant, + manager: MockFlowManager, + task_init_side_effect: Exception | None, + task_next_side_effect: Exception | None, + flow_result_before_init: data_entry_flow.FlowResultType, + flow_result_after_init: data_entry_flow.FlowResultType, + flow_result_after_next: data_entry_flow.FlowResultType, + flow_init_events: int, + flow_next_events: int, + manager_call_after_init: Callable[ + [MockFlowManager, data_entry_flow.FlowResult], Any + ], + manager_call_after_next: Callable[ + [MockFlowManager, data_entry_flow.FlowResult], Any + ], + before_init_task_side_effect: Callable[ + [asyncio.Event, asyncio.Event, asyncio.Event], None + ], + before_next_task_side_effect: Callable[ + [asyncio.Event, asyncio.Event, asyncio.Event], None + ], +) -> None: + """Test chaining two steps with progress_step decorators.""" + manager.hass = hass + events = [] + event_received_evt = asyncio.Event() + task_init_evt = asyncio.Event() + task_next_evt = asyncio.Event() + task_init_result = Mock() + task_init_result.side_effect = task_init_side_effect + task_next_result = Mock() + task_next_result.side_effect = task_next_side_effect + + @callback + def capture_events(event: Event) -> None: + events.append(event) + event_received_evt.set() + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + + def async_remove(self) -> None: + # Disable event received event to allow test to finish if flow is aborted. + event_received_evt.set() + + @data_entry_flow.progress_step() + async def async_step_init(self, user_input=None): + await task_init_evt.wait() + task_init_result() + + return await self.async_step_next() + + @data_entry_flow.progress_step() + async def async_step_next(self, user_input=None): + await task_next_evt.wait() + task_next_result() + + return await self.async_step_finish() + + async def async_step_finish(self, user_input=None): + return self.async_create_entry(data={}) + + hass.bus.async_listen( + data_entry_flow.EVENT_DATA_ENTRY_FLOW_PROGRESSED, + capture_events, + ) + + # Run side effect before first event is awaited + before_init_task_side_effect(event_received_evt, task_init_evt, task_next_evt) + + result = await manager.async_init("test") + assert result["type"] == flow_result_before_init + assert len(manager.async_progress()) == 1 + assert len(manager.async_progress_by_handler("test")) == 1 + assert manager.async_get(result["flow_id"])["handler"] == "test" + + # Set task init done and wait for event + task_init_evt.set() + await event_received_evt.wait() + event_received_evt.clear() + assert len(events) == flow_init_events + + # Run side effect before second event is awaited + before_next_task_side_effect(event_received_evt, task_init_evt, task_next_evt) + + # Continue the flow if needed. + result = await manager_call_after_init(manager, result) + assert result["type"] == flow_result_after_init + + # Set task next done and wait for event + task_next_evt.set() + await event_received_evt.wait() + event_received_evt.clear() + assert len(events) == flow_next_events + + # Continue the flow if needed. + result = await manager_call_after_next(manager, result) + assert result["type"] == flow_result_after_next + + async def test_abort_flow_exception_step(manager: MockFlowManager) -> None: """Test that the AbortFlow exception works in a step."""