# Mirror of https://github.com/home-assistant/core.git, synced 2026-07-04 05:05:38 +01:00.
# 185 lines · 6.9 KiB · Python
"""Mypy plugin: flag ``==``/``!=`` between two operands of the same enum class.

Scope is intentionally narrow: only **plain ``enum.Enum`` subclasses** are
flagged by default, because equality on plain ``Enum`` members is
identity-based — ``a == b`` and ``a is b`` produce the same result there.

Any enum with a base outside the ``Enum`` hierarchy is **skipped**, because
such a mixin typically gives it a value-based ``__eq__``. The mixin may be a
primitive (``StrEnum``/``IntEnum``, or the legacy ``class X(str, Enum)`` /
``class X(int, Enum)`` / ``class X(float, Enum)`` forms), a ``@dataclass``, or
a ``NamedTuple``. Such enums compare by value and accept raw operands: callers
routinely pass ``"on"`` where a ``HVACMode`` parameter is annotated, and ``==``
silently makes that work while ``is`` silently breaks it. Switching those
sites to ``is`` is a runtime-behavior change, not a refactor.

The check is deliberately conservative. It is decided from the MRO shape, not
from where ``__eq__`` is defined, because mypy does not model the ``__eq__``
that ``@dataclass`` synthesizes. So an enum whose only non-``Enum`` base is a
plain helper class that does not override ``__eq__`` is still skipped, even
though it is really identity-based. That under-flags such enums rather than
risk an unsafe ``==``→``is`` rewrite — the safe direction.

The ``_FRAMEWORK_GUARANTEED_ENUMS`` set carves back in
``StrEnum``/``IntEnum`` classes where the HA framework itself controls
every call site and guarantees the value is the enum instance. Currently that
is just ``homeassistant.data_entry_flow.FlowResultType``.

``enum.Flag``/``enum.IntFlag`` are always exempt — bitwise ``==`` is
idiomatic there.
"""
|
|
|
|
from collections.abc import Callable
|
|
|
|
from mypy.errorcodes import ErrorCode
|
|
from mypy.nodes import TypeInfo
|
|
from mypy.plugin import MethodContext, Plugin
|
|
from mypy.types import Instance, LiteralType, Type, UnionType, get_proper_type
|
|
|
|
# Error code attached to every diagnostic this plugin emits, so mypy reports
# (and lets users select or suppress) the check under its own name.
ENUM_IDENTITY = ErrorCode(
    code="home-assistant-enum-identity-compare",
    description="Use `is`/`is not` to compare two operands of the same enum class.",
    category="Home Assistant",
)
|
|
|
|
_PLAIN_ENUM_BASE = "enum.Enum"
|
|
_FLAG_BASES = frozenset({"enum.Flag", "enum.IntFlag"})
|
|
|
|
|
|
def _is_value_mixin(base: TypeInfo) -> bool:
    """Report whether ``base`` would make an enum compare by value.

    ``==`` on an enum is identity-based only when every base outside the
    ``Enum`` hierarchy is ``object``. Any other base — a primitive such as
    ``str``/``int``/``float``, a ``@dataclass``, or a ``NamedTuple`` — is
    treated as supplying value comparison, so ``is`` would not be equivalent
    to ``==``.

    The answer is purely structural (taken from the MRO), independent of where
    or how ``__eq__`` is actually defined. A class that itself derives from
    ``Enum`` (e.g. an intermediate ``class Base(Enum)``) is not a value mixin,
    which keeps enums built on top of it flaggable.

    Returns:
        ``False`` for ``builtins.object`` and for any class whose MRO contains
        ``enum.Enum``; ``True`` for every other class.
    """
    # ``object`` short-circuits before the MRO is inspected at all.
    return base.fullname != "builtins.object" and all(
        ancestor.fullname != _PLAIN_ENUM_BASE for ancestor in base.mro
    )
|
|
|
|
|
|
# StrEnum/IntEnum classes where every callsite assigning the value is
|
|
# framework-controlled, so the runtime value is guaranteed to be the
|
|
# enum instance (never a raw string/int). Audited additions only.
|
|
_FRAMEWORK_GUARANTEED_ENUMS = frozenset(
|
|
{
|
|
"homeassistant.data_entry_flow.FlowResultType",
|
|
}
|
|
)
|
|
|
|
|
|
def _enum_class(t: Type | None) -> TypeInfo | None:
    """Resolve ``t`` to the enum class whose comparisons this plugin tracks.

    Accepted shapes:

    - ``Instance`` — the plain case, e.g. an operand annotated ``SourceCodes``.
    - ``LiteralType`` — one literal enum member such as ``Literal[E.A]``; its
      ``fallback`` instance (the enum class) is examined instead.
    - ``UnionType`` — accepted only when every member resolves, recursively,
      to the same enum class (compared by fullname); that class is returned.

    Returns:
        The enum's ``TypeInfo``, or ``None`` when the type is not tracked:
        ``Flag``/``IntFlag`` enums (bitwise ``==`` is idiomatic), value-based
        enums absent from ``_FRAMEWORK_GUARANTEED_ENUMS``, non-enum classes,
        and anything else (``Any``, ``None``, mixed unions, ...).
    """
    if t is None:
        return None
    proper = get_proper_type(t)

    if isinstance(proper, UnionType):
        # Every arm must resolve to one and the same tracked enum class;
        # the first resolved arm is the one reported.
        shared: TypeInfo | None = None
        for arm in proper.items:
            arm_info = _enum_class(arm)
            if arm_info is None:
                return None
            if shared is not None and shared.fullname != arm_info.fullname:
                return None
            if shared is None:
                shared = arm_info
        return shared

    if isinstance(proper, LiteralType):
        # ``Literal[E.A]`` -> ``E``
        proper = proper.fallback
    if not isinstance(proper, Instance):
        return None

    info = proper.type
    saw_enum_root = False
    saw_value_mixin = False
    for ancestor in info.mro:
        ancestor_name = ancestor.fullname
        if ancestor_name in _FLAG_BASES:
            return None
        if ancestor_name == _PLAIN_ENUM_BASE:
            saw_enum_root = True
        elif _is_value_mixin(ancestor):
            saw_value_mixin = True

    if not saw_enum_root:
        # Not an enum at all.
        return None
    if saw_value_mixin and info.fullname not in _FRAMEWORK_GUARANTEED_ENUMS:
        # Untrusted value-based enum: callers may hand in the raw primitive,
        # in which case `is` would diverge from `==`.
        return None
    return info
|
|
|
|
|
|
def _emit(ctx: MethodContext, op: str, enum_cls: TypeInfo) -> Type:
    """Report an ``op`` comparison between ``enum_cls`` instances.

    The error is attached to ``ctx.context`` under ``ENUM_IDENTITY``. The
    inferred type of the expression is left untouched: mypy's default return
    type is handed back unchanged.
    """
    if op == "==":
        replacement = "is"
    else:
        replacement = "is not"
    message = (
        f"Use `{replacement}` instead of `{op}` to compare "
        f"`{enum_cls.name}` enum instances"
    )
    ctx.api.fail(message, ctx.context, code=ENUM_IDENTITY)
    return ctx.default_return_type
|
|
|
|
|
|
def _make_hook(op: str) -> Callable[[MethodContext], Type]:
    """Build the method hook for one comparison operator.

    ``op`` is the operator text (``"=="`` for ``__eq__``, ``"!="`` for
    ``__ne__``). It is captured by the returned closure and used in the
    emitted message.
    """

    def hook(ctx: MethodContext) -> Type:
        # Left operand: the receiver of ``__eq__``/``__ne__``.
        lhs_enum = _enum_class(ctx.type)
        if lhs_enum is None:
            return ctx.default_return_type

        # Right operand: the first positional argument, when mypy supplied one.
        first_formal = ctx.arg_types[0] if ctx.arg_types else []
        rhs_type = first_formal[0] if first_formal else None
        rhs_enum = _enum_class(rhs_type)

        # Only same-class comparisons are reported.
        if rhs_enum is None or rhs_enum.fullname != lhs_enum.fullname:
            return ctx.default_return_type
        return _emit(ctx, op, lhs_enum)

    return hook
|
|
|
|
|
|
# One prebuilt hook per operator, shared by every lookup in get_method_hook.
_EQ_HOOK, _NE_HOOK = _make_hook("=="), _make_hook("!=")
|
|
|
|
|
|
class HassEnumIdentityPlugin(Plugin):
    """Mypy plugin that reports ``==``/``!=`` between same-class enum operands."""

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        """Map an ``__eq__``/``__ne__`` method fullname to its hook.

        ``a == b`` is checked as ``a.__eq__(b)`` and ``a != b`` as
        ``a.__ne__(b)``. The method's fullname (e.g. ``"pkg.Cls.__eq__"``)
        therefore identifies which operator is being checked. Any other
        method yields ``None``.
        """
        for suffix, method_hook in ((".__eq__", _EQ_HOOK), (".__ne__", _NE_HOOK)):
            if fullname.endswith(suffix):
                return method_hook
        return None
|
|
|
|
|
|
def plugin(version: str) -> type[Plugin]:
    """Return the plugin class for mypy to instantiate.

    ``version`` is the mypy version string passed by the loader; it is not
    inspected.
    """
    plugin_cls: type[Plugin] = HassEnumIdentityPlugin
    return plugin_cls
|