mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Misc typing improvements (#153322)
This commit is contained in:
@@ -2,9 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar("T", dict[str, Any], list[Any], None)
|
||||
from typing import Any
|
||||
|
||||
TRANSLATION_MAP = {
|
||||
"wan_rx": "sensor_rx_bytes",
|
||||
@@ -36,7 +34,7 @@ def clean_dict(raw: dict[str, Any]) -> dict[str, Any]:
|
||||
return {k: v for k, v in raw.items() if v is not None or k.endswith("state")}
|
||||
|
||||
|
||||
def translate_to_legacy(raw: T) -> T:
|
||||
def translate_to_legacy[T: (dict[str, Any], list[Any], None)](raw: T) -> T:
|
||||
"""Translate raw data to legacy format for dicts and lists."""
|
||||
|
||||
if raw is None:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Coordinator module for managing Growatt data fetching."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -145,7 +147,7 @@ class GrowattCoordinator(DataUpdateCoordinator[dict[str, Any]]):
|
||||
return self.data.get("currency")
|
||||
|
||||
def get_data(
|
||||
self, entity_description: "GrowattSensorEntityDescription"
|
||||
self, entity_description: GrowattSensorEntityDescription
|
||||
) -> str | int | float | None:
|
||||
"""Get the data."""
|
||||
variable = entity_description.api_key
|
||||
|
||||
@@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
|
||||
from .const import CONF_DEVICE_DATA, CONF_DEVICE_TYPE
|
||||
from .coordinator import INKBIRDActiveBluetoothProcessorCoordinator
|
||||
|
||||
INKBIRDConfigEntry = ConfigEntry[INKBIRDActiveBluetoothProcessorCoordinator]
|
||||
type INKBIRDConfigEntry = ConfigEntry[INKBIRDActiveBluetoothProcessorCoordinator]
|
||||
|
||||
PLATFORMS: list[Platform] = [Platform.SENSOR]
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
UPDATE_INTERVAL = datetime.timedelta(minutes=30)
|
||||
TIMEOUT = 10
|
||||
|
||||
TokenManager = Callable[[], Awaitable[str]]
|
||||
type TokenManager = Callable[[], Awaitable[str]]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
import re
|
||||
from typing import Generic, TypeVar, cast
|
||||
from typing import cast
|
||||
|
||||
from qbusmqttapi.discovery import QbusMqttDevice, QbusMqttOutput
|
||||
from qbusmqttapi.factory import QbusMqttMessageFactory, QbusMqttTopicFactory
|
||||
@@ -20,8 +20,6 @@ from .coordinator import QbusControllerCoordinator
|
||||
|
||||
_REFID_REGEX = re.compile(r"^\d+\/(\d+(?:\/\d+)?)$")
|
||||
|
||||
StateT = TypeVar("StateT", bound=QbusMqttState)
|
||||
|
||||
|
||||
def create_new_entities(
|
||||
coordinator: QbusControllerCoordinator,
|
||||
@@ -78,7 +76,7 @@ def create_unique_id(serial_number: str, suffix: str) -> str:
|
||||
return f"ctd_{serial_number}_{suffix}"
|
||||
|
||||
|
||||
class QbusEntity(Entity, ABC, Generic[StateT]):
|
||||
class QbusEntity[StateT: QbusMqttState](Entity, ABC):
|
||||
"""Representation of a Qbus entity."""
|
||||
|
||||
_state_cls: type[StateT] = cast(type[StateT], QbusMqttState)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import TypeVar
|
||||
|
||||
from bleak.exc import BleakError
|
||||
from togrill_bluetooth.client import Client
|
||||
@@ -39,8 +38,6 @@ type ToGrillConfigEntry = ConfigEntry[ToGrillCoordinator]
|
||||
SCAN_INTERVAL = timedelta(seconds=30)
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PacketType = TypeVar("PacketType", bound=Packet)
|
||||
|
||||
|
||||
def get_version_string(packet: PacketA0Notify) -> str:
|
||||
"""Construct a version string from packet data."""
|
||||
@@ -179,9 +176,9 @@ class ToGrillCoordinator(DataUpdateCoordinator[dict[tuple[int, int | None], Pack
|
||||
self.client = await self._connect_and_update_registry()
|
||||
return self.client
|
||||
|
||||
def get_packet(
|
||||
self, packet_type: type[PacketType], probe=None
|
||||
) -> PacketType | None:
|
||||
def get_packet[PacketT: Packet](
|
||||
self, packet_type: type[PacketT], probe=None
|
||||
) -> PacketT | None:
|
||||
"""Get a cached packet of a certain type."""
|
||||
|
||||
if packet := self.data.get((packet_type.type, probe)):
|
||||
|
||||
@@ -16,7 +16,7 @@ from pathlib import Path
|
||||
import re
|
||||
import secrets
|
||||
from time import monotonic
|
||||
from typing import Any, Final, Generic, Protocol, TypeVar
|
||||
from typing import Any, Final, Protocol
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
@@ -628,10 +628,7 @@ class HasLastUsed(Protocol):
|
||||
last_used: float
|
||||
|
||||
|
||||
T = TypeVar("T", bound=HasLastUsed)
|
||||
|
||||
|
||||
class DictCleaning(Generic[T]):
|
||||
class DictCleaning[T: HasLastUsed]:
|
||||
"""Helper to clean up the stale sessions."""
|
||||
|
||||
unsub: CALLBACK_TYPE | None = None
|
||||
|
||||
@@ -8,7 +8,7 @@ from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import Any, Generic, TypeVar, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from volvocarsapi.api import VolvoCarsApi
|
||||
from volvocarsapi.models import (
|
||||
@@ -64,10 +64,7 @@ def _is_invalid_api_field(field: VolvoCarsApiBaseModel | None) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
T = TypeVar("T", bound=dict, default=dict[str, Any])
|
||||
|
||||
|
||||
class VolvoBaseCoordinator(DataUpdateCoordinator[T], Generic[T]):
|
||||
class VolvoBaseCoordinator[T: dict = dict[str, Any]](DataUpdateCoordinator[T]):
|
||||
"""Volvo base coordinator."""
|
||||
|
||||
config_entry: VolvoConfigEntry
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from aiohttp import ClientResponseError
|
||||
from weatherflow4py.api import WeatherFlowRestAPI
|
||||
@@ -29,10 +28,8 @@ from homeassistant.util.ssl import client_context
|
||||
|
||||
from .const import DOMAIN, LOGGER
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseWeatherFlowCoordinator(DataUpdateCoordinator[dict[int, T]], ABC, Generic[T]):
|
||||
class BaseWeatherFlowCoordinator[T](DataUpdateCoordinator[dict[int, T]], ABC):
|
||||
"""Base class for WeatherFlow coordinators."""
|
||||
|
||||
def __init__(
|
||||
@@ -106,9 +103,7 @@ class WeatherFlowCloudUpdateCoordinatorREST(
|
||||
return self.data[station_id].station.name
|
||||
|
||||
|
||||
class BaseWebsocketCoordinator(
|
||||
BaseWeatherFlowCoordinator[dict[int, T | None]], ABC, Generic[T]
|
||||
):
|
||||
class BaseWebsocketCoordinator[T](BaseWeatherFlowCoordinator[dict[int, T | None]], ABC):
|
||||
"""Base class for websocket coordinators."""
|
||||
|
||||
_event_type: EventType
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Helpers functions for the Workday component."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -20,7 +22,7 @@ from .const import CONF_REMOVE_HOLIDAYS, DOMAIN, LOGGER
|
||||
|
||||
async def async_validate_country_and_province(
|
||||
hass: HomeAssistant,
|
||||
entry: "WorkdayConfigEntry",
|
||||
entry: WorkdayConfigEntry,
|
||||
country: str | None,
|
||||
province: str | None,
|
||||
) -> None:
|
||||
@@ -180,7 +182,7 @@ def get_holidays_object(
|
||||
|
||||
def add_remove_custom_holidays(
|
||||
hass: HomeAssistant,
|
||||
entry: "WorkdayConfigEntry",
|
||||
entry: WorkdayConfigEntry,
|
||||
country: str | None,
|
||||
calc_add_holidays: list[DateLike],
|
||||
calc_remove_holidays: list[str],
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
class MockAirzoneCloudApi(AirzoneCloudApi):
|
||||
"""Mock AirzoneCloudApi class."""
|
||||
|
||||
async def mock_update(self: "AirzoneCloudApi"):
|
||||
async def mock_update(self):
|
||||
"""Mock AirzoneCloudApi _update function."""
|
||||
await self.update_polling()
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ def mock_get_forecast_api_error():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_irm_kmi_api(request: pytest.FixtureRequest) -> Generator[None, MagicMock]:
|
||||
def mock_irm_kmi_api(request: pytest.FixtureRequest) -> Generator[MagicMock]:
|
||||
"""Return a mocked IrmKmi api client."""
|
||||
fixture: str = "forecast.json"
|
||||
|
||||
@@ -111,9 +111,7 @@ def mock_irm_kmi_api_high_low_temp():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_exception_irm_kmi_api(
|
||||
request: pytest.FixtureRequest,
|
||||
) -> Generator[None, MagicMock]:
|
||||
def mock_exception_irm_kmi_api(request: pytest.FixtureRequest) -> Generator[MagicMock]:
|
||||
"""Return a mocked IrmKmi api client that will raise an error upon refreshing data."""
|
||||
with patch(
|
||||
"homeassistant.components.irm_kmi.IrmKmiApiClientHa", autospec=True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Tests for the Twitch component."""
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from twitchAPI.object.base import TwitchObject
|
||||
|
||||
@@ -20,10 +20,7 @@ async def setup_integration(hass: HomeAssistant, config_entry: MockConfigEntry)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
TwitchType = TypeVar("TwitchType", bound=TwitchObject)
|
||||
|
||||
|
||||
class TwitchIterObject(Generic[TwitchType]):
|
||||
class TwitchIterObject[TwitchT: TwitchObject]:
|
||||
"""Twitch object iterator."""
|
||||
|
||||
raw_data: JsonArrayType
|
||||
@@ -31,14 +28,14 @@ class TwitchIterObject(Generic[TwitchType]):
|
||||
total: int
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, fixture: str, target_type: type[TwitchType]
|
||||
self, hass: HomeAssistant, fixture: str, target_type: type[TwitchT]
|
||||
) -> None:
|
||||
"""Initialize object."""
|
||||
self.hass = hass
|
||||
self.fixture = fixture
|
||||
self.target_type = target_type
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[TwitchType]:
|
||||
async def __aiter__(self) -> AsyncIterator[TwitchT]:
|
||||
"""Return async iterator."""
|
||||
if not hasattr(self, "raw_data"):
|
||||
self.raw_data = await async_load_json_array_fixture(
|
||||
@@ -50,18 +47,18 @@ class TwitchIterObject(Generic[TwitchType]):
|
||||
yield item
|
||||
|
||||
|
||||
async def get_generator(
|
||||
hass: HomeAssistant, fixture: str, target_type: type[TwitchType]
|
||||
) -> AsyncGenerator[TwitchType]:
|
||||
async def get_generator[TwitchT: TwitchObject](
|
||||
hass: HomeAssistant, fixture: str, target_type: type[TwitchT]
|
||||
) -> AsyncGenerator[TwitchT]:
|
||||
"""Return async generator."""
|
||||
data = await async_load_json_array_fixture(hass, fixture, DOMAIN)
|
||||
async for item in get_generator_from_data(data, target_type):
|
||||
yield item
|
||||
|
||||
|
||||
async def get_generator_from_data(
|
||||
items: list[dict[str, Any]], target_type: type[TwitchType]
|
||||
) -> AsyncGenerator[TwitchType]:
|
||||
async def get_generator_from_data[TwitchT: TwitchObject](
|
||||
items: list[dict[str, Any]], target_type: type[TwitchT]
|
||||
) -> AsyncGenerator[TwitchT]:
|
||||
"""Return async generator."""
|
||||
for item in items:
|
||||
yield target_type(**item)
|
||||
|
||||
@@ -41,7 +41,7 @@ HUB_DATA = {
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_vegehub() -> Generator[Any, Any, Any]:
|
||||
def mock_vegehub() -> Generator[Any]:
|
||||
"""Mock the VegeHub library."""
|
||||
with patch(
|
||||
"homeassistant.components.vegehub.config_flow.VegeHub", autospec=True
|
||||
|
||||
@@ -40,7 +40,7 @@ DISCOVERY_INFO = zeroconf.ZeroconfServiceInfo(
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_setup_entry() -> Generator[Any, Any, Any]:
|
||||
def mock_setup_entry() -> Generator[Any]:
|
||||
"""Prevent the actual integration from being set up."""
|
||||
with (
|
||||
patch("homeassistant.components.vegehub.async_setup_entry", return_value=True),
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Test Zeroconf multiple instance protection."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Self
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -20,7 +23,7 @@ class MockZeroconf:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""Initialize the mock."""
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> "MockZeroconf":
|
||||
def __new__(cls, *args, **kwargs) -> Self:
|
||||
"""Return the shared instance."""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user