mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 21:06:19 +00:00
Make FlowHandler.context a typed dict (#126291)
* Make FlowHandler.context a typed dict * Adjust typing * Adjust typing * Avoid calling ConfigFlowContext constructor in hot path
This commit is contained in:
@@ -87,7 +87,10 @@ STEP_ID_OPTIONAL_STEPS = {
|
||||
}
|
||||
|
||||
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
|
||||
_FlowContextT = TypeVar("_FlowContextT", bound="FlowContext", default="FlowContext")
|
||||
_FlowResultT = TypeVar(
|
||||
"_FlowResultT", bound="FlowResult[Any, Any]", default="FlowResult"
|
||||
)
|
||||
_HandlerT = TypeVar("_HandlerT", default=str)
|
||||
|
||||
|
||||
@@ -139,10 +142,17 @@ class AbortFlow(FlowError):
|
||||
self.description_placeholders = description_placeholders
|
||||
|
||||
|
||||
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
|
||||
class FlowContext(TypedDict, total=False):
|
||||
"""Typed context dict."""
|
||||
|
||||
show_advanced_options: bool
|
||||
source: str
|
||||
|
||||
|
||||
class FlowResult(TypedDict, Generic[_FlowContextT, _HandlerT], total=False):
|
||||
"""Typed result dict."""
|
||||
|
||||
context: dict[str, Any]
|
||||
context: _FlowContextT
|
||||
data_schema: vol.Schema | None
|
||||
data: Mapping[str, Any]
|
||||
description_placeholders: Mapping[str, str | None] | None
|
||||
@@ -189,7 +199,7 @@ def _map_error_to_schema_errors(
|
||||
schema_errors[path_part_str] = error.error_message
|
||||
|
||||
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
class FlowManager(abc.ABC, Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
@@ -201,12 +211,14 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._preview: set[_HandlerT] = set()
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
|
||||
self._progress: dict[
|
||||
str, FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
] = {}
|
||||
self._handler_progress_index: defaultdict[
|
||||
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
_HandlerT, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
|
||||
] = defaultdict(set)
|
||||
self._init_data_process_index: defaultdict[
|
||||
type, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
type, set[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]
|
||||
] = defaultdict(set)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -214,9 +226,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
self,
|
||||
handler_key: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: _FlowContextT | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> FlowHandler[_FlowResultT, _HandlerT]:
|
||||
) -> FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]:
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
@@ -224,7 +236,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
self,
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
result: _FlowResultT,
|
||||
) -> _FlowResultT:
|
||||
"""Finish a data entry flow.
|
||||
|
||||
@@ -233,7 +247,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
"""
|
||||
|
||||
async def async_post_init(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
self,
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
result: _FlowResultT,
|
||||
) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
||||
@@ -288,7 +304,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
@callback
|
||||
def _async_progress_by_handler(
|
||||
self, handler: _HandlerT, match_context: dict[str, Any] | None
|
||||
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
|
||||
) -> list[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]]:
|
||||
"""Return the flows in progress by handler.
|
||||
|
||||
If match_context is specified, only return flows with a context that
|
||||
@@ -307,12 +323,12 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
self,
|
||||
handler: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
context: _FlowContextT | None = None,
|
||||
data: Any = None,
|
||||
) -> _FlowResultT:
|
||||
"""Start a data entry flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
context = cast(_FlowContextT, {})
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
raise UnknownFlow("Flow was not created")
|
||||
@@ -452,7 +468,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
|
||||
@callback
|
||||
def _async_add_flow_progress(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Add a flow to in progress."""
|
||||
if flow.init_data is not None:
|
||||
@@ -462,7 +478,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
|
||||
@callback
|
||||
def _async_remove_flow_from_index(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
if flow.init_data is not None:
|
||||
@@ -489,7 +505,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
flow: FlowHandler[_FlowResultT, _HandlerT],
|
||||
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
|
||||
step_id: str,
|
||||
user_input: dict | BaseServiceInfo | None,
|
||||
) -> _FlowResultT:
|
||||
@@ -566,7 +582,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
return result
|
||||
|
||||
def _raise_if_step_does_not_exist(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], step_id: str
|
||||
) -> None:
|
||||
"""Raise if the step does not exist."""
|
||||
method = f"async_step_{step_id}"
|
||||
@@ -578,7 +594,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
)
|
||||
|
||||
async def _async_setup_preview(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
self, flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Set up preview for a flow handler."""
|
||||
if flow.handler not in self._preview:
|
||||
@@ -588,7 +604,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
@callback
|
||||
def _async_flow_handler_to_flow_result(
|
||||
self,
|
||||
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
|
||||
flows: Iterable[FlowHandler[_FlowContextT, _FlowResultT, _HandlerT]],
|
||||
include_uninitialized: bool,
|
||||
) -> list[_FlowResultT]:
|
||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||
@@ -610,7 +626,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
]
|
||||
|
||||
|
||||
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
class FlowHandler(Generic[_FlowContextT, _FlowResultT, _HandlerT]):
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
_flow_result: type[_FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
@@ -624,7 +640,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
hass: HomeAssistant = None # type: ignore[assignment]
|
||||
handler: _HandlerT = None # type: ignore[assignment]
|
||||
# Ensure the attribute has a subscriptable, but immutable, default value.
|
||||
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
|
||||
context: _FlowContextT = MappingProxyType({}) # type: ignore[assignment]
|
||||
|
||||
# Set by _async_create_flow callback
|
||||
init_step = "init"
|
||||
@@ -643,12 +659,12 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
@property
|
||||
def source(self) -> str | None:
|
||||
"""Source that initialized the flow."""
|
||||
return self.context.get("source", None) # type: ignore[no-any-return]
|
||||
return self.context.get("source", None) # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
def show_advanced_options(self) -> bool:
|
||||
"""If we should show advanced options."""
|
||||
return self.context.get("show_advanced_options", False) # type: ignore[no-any-return]
|
||||
return self.context.get("show_advanced_options", False) # type: ignore[return-value]
|
||||
|
||||
def add_suggested_values_to_schema(
|
||||
self, data_schema: vol.Schema, suggested_values: Mapping[str, Any] | None
|
||||
|
||||
Reference in New Issue
Block a user