1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Make FlowHandler.context a typed dict (#126291)

* Make FlowHandler.context a typed dict

* Adjust typing

* Adjust typing

* Avoid calling ConfigFlowContext constructor in hot path
This commit is contained in:
Erik Montnemery
2024-10-08 12:18:45 +02:00
committed by GitHub
parent 217165208b
commit d6ee10a543
19 changed files with 175 additions and 99 deletions

View File

@@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
_FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext")
_FlowResultT = TypeVar(
"_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult"
)
_HandlerT = TypeVar("_HandlerT", default=str)
@@ -139,10 +142,17 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
class FlowContext(TypedDict, total=False):
"""Typed context dict."""
show_advanced_options: bool
source: str
class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
"""Typed result dict."""
context: dict[str, Any]
context: _FlowContextT
data_schema: vol.Schema | None
data: Mapping[str, Any]
description_placeholders: Mapping[str, str | None] | None
@@ -189,7 +199,7 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Manage all the flows that are in progress."""
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
@@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[_HandlerT] = set()
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
self._progress: dict[
str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
] = {}
self._handler_progress_index: defaultdict[
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
_HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
] = defaultdict(set)
self._init_data_process_index: defaultdict[
type, set[FlowHandler[_FlowResultT, _HandlerT]]
type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
] = defaultdict(set)
@abc.abstractmethod
@@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
self,
handler_key: _HandlerT,
*,
context: dict[str, Any] | None = None,
context: _FlowContextT | None = None,
data: dict[str, Any] | None = None,
) -> FlowHandler[_FlowResultT, _HandlerT]:
) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@abc.abstractmethod
async def async_finish_flow(
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
result: _FlowResultT,
) -> _FlowResultT:
"""Finish a data entry flow.
@@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""
async def async_post_init(
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
result: _FlowResultT,
) -> None:
"""Entry has finished executing its first step asynchronously."""
@@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_progress_by_handler(
self, handler: _HandlerT, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
self,
handler: _HandlerT,
*,
context: dict[str, Any] | None = None,
context: _FlowContextT | None = None,
data: Any = None,
) -> _FlowResultT:
"""Start a data entry flow."""
if context is None:
context = {}
context = cast(_FlowContextT, {})
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise UnknownFlow("Flow was not created")
@@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_add_flow_progress(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
@@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_remove_flow_from_index(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
@@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
async def _async_handle_step(
self,
flow: FlowHandler[_FlowResultT, _HandlerT],
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
@@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
return result
def _raise_if_step_does_not_exist(
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
)
async def _async_setup_preview(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
@@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
@callback
def _async_flow_handler_to_flow_result(
self,
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]],
include_uninitialized: bool,
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
@@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
]
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
"""Handle a data entry flow."""
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
@@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
hass: HomeAssistant = None # type: ignore[assignment]
handler: _HandlerT = None # type: ignore[assignment]
# Ensure the attribute has a subscriptable, but immutable, default value.
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment]
# Set by _async_create_flow callback
init_step = "init"
@@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
@property
def source(self) -> str | None:
"""Source that initialized the flow."""
return self.context.get("source", None) # type: ignore[no-any-return]
return self.context.get("source", None) # type: ignore[return-value]
@property
def show_advanced_options(self) -> bool:
"""If we should show advanced options."""
return self.context.get("show_advanced_options", False) # type: ignore[no-any-return]
return self.context.get("show_advanced_options", False) # type: ignore[return-value]
def add_suggested_values_to_schema(
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None