1
0
mirror of https://github.com/home-assistant/core.git synced 2026-07-04 05:05:38 +01:00
Files
core/mypy_plugins/enum_identity_compare.py

185 lines
6.9 KiB
Python

"""Mypy plugin: flag ``==``/``!=`` between two operands of the same enum class.
Scope is intentionally narrow: only **plain ``enum.Enum`` subclasses** are
flagged by default, because Python's ``Enum.__eq__`` is identity-based —
``a == b`` and ``a is b`` produce the same result there.

Any enum with a base outside the ``Enum`` hierarchy is **skipped**, because
such a mixin typically gives it a value-based ``__eq__``: a primitive
(``StrEnum``/``IntEnum`` or the legacy ``class X(str, Enum)`` /
``class X(int, Enum)`` / ``class X(float, Enum)`` form), a ``@dataclass``, or a
``NamedTuple``. Their ``==`` compares by value and accepts raw operands:
callers routinely pass ``"on"`` where a ``HVACMode`` parameter is annotated,
and ``==`` silently makes that work while ``is`` silently breaks it. Switching
those sites to ``is`` is a runtime-behavior change, not a refactor.

The check is deliberately conservative — it is decided from the MRO shape, not
from where ``__eq__`` is defined (mypy does not model the ``__eq__`` a
``@dataclass`` synthesizes). So an enum whose only non-``Enum`` base is a plain
helper class that does not override ``__eq__`` is still skipped even though it
is really identity-based. That under-flags such enums rather than risk an
unsafe ``==``→``is`` rewrite — the safe direction.

The ``_FRAMEWORK_GUARANTEED_ENUMS`` set carves back in
``StrEnum``/``IntEnum`` classes where the HA framework itself controls
every callsite and guarantees the value is the enum instance — currently
just ``homeassistant.data_entry_flow.FlowResultType``.

``enum.Flag``/``enum.IntFlag`` are always exempt — bitwise ``==`` is
idiomatic there.
"""
from collections.abc import Callable
from mypy.errorcodes import ErrorCode
from mypy.nodes import TypeInfo
from mypy.plugin import MethodContext, Plugin
from mypy.types import Instance, LiteralType, Type, UnionType, get_proper_type
# Error code attached to every diagnostic this plugin reports (see ``_emit``).
ENUM_IDENTITY = ErrorCode(
    code="home-assistant-enum-identity-compare",
    description="Use `is`/`is not` to compare two operands of the same enum class.",
    category="Home Assistant",
)
# Fully qualified name of the root ``Enum`` class, matched against
# ``TypeInfo.fullname`` when walking an MRO.
_PLAIN_ENUM_BASE = "enum.Enum"
# Fully qualified names of the flag-style enum roots; any class with one of
# these in its MRO is never flagged.
_FLAG_BASES = frozenset(("enum.Flag", "enum.IntFlag"))
def _is_value_mixin(base: TypeInfo) -> bool:
    """Report whether ``base`` lies outside the ``Enum`` hierarchy.

    The decision is purely structural: it looks at ``base``'s own MRO and
    never at where ``__eq__`` is defined. A base counts as a value mixin
    unless it is ``builtins.object`` or has ``enum.Enum`` somewhere in its
    MRO (which covers ``enum.Enum`` itself and intermediate classes such as
    ``class Base(Enum)``).

    Rationale: a non-``Enum`` base — a primitive such as ``str``/``int``, a
    ``@dataclass``, or a ``NamedTuple`` — typically brings value-based
    comparison with it, so ``is`` would not be a safe substitute for ``==``.
    """
    if base.fullname == "builtins.object":
        return False
    # Anything that descends from (or is) ``enum.Enum`` keeps the enum
    # identity-based and therefore is not a value mixin.
    for ancestor in base.mro:
        if ancestor.fullname == _PLAIN_ENUM_BASE:
            return False
    return True
# Allow-list of value-based enums (``StrEnum``/``IntEnum``) that are flagged
# anyway. Each entry is a class for which every callsite assigning the value
# is framework-controlled, so the runtime value is guaranteed to be the enum
# instance and never a raw string/int. Audited additions only.
_FRAMEWORK_GUARANTEED_ENUMS = frozenset(
    ("homeassistant.data_entry_flow.FlowResultType",)
)
def _enum_class(t: Type | None) -> TypeInfo | None:
    """Resolve ``t`` to the enum class this plugin should flag, if any.

    Shapes that are understood:

    - ``Instance``: the direct case, e.g. ``source: SourceCodes``.
    - ``LiteralType``: a single literal member, e.g. ``Literal[E.A]``; it is
      replaced by its ``fallback`` instance before inspection.
    - ``UnionType``: every variant is resolved recursively, and the result
      is only kept when all variants agree on one enum class.

    ``None`` is returned when:

    - ``t`` is ``None`` or does not resolve to an ``Instance``
      (``Any``, ``None``, ...);
    - a union has a variant that is not a flaggable enum, or its variants
      resolve to different enum classes;
    - the class has ``enum.Flag``/``enum.IntFlag`` in its MRO;
    - the class has no ``enum.Enum`` in its MRO;
    - the class has a value mixin in its MRO and is not listed in
      ``_FRAMEWORK_GUARANTEED_ENUMS``.
    """
    if t is None:
        return None

    proper = get_proper_type(t)

    if isinstance(proper, UnionType):
        shared: TypeInfo | None = None
        for member in proper.items:
            member_info = _enum_class(member)
            if member_info is None:
                # One non-enum (or exempt) variant disqualifies the union.
                return None
            if shared is not None and shared.fullname != member_info.fullname:
                # Variants disagree on the enum class.
                return None
            if shared is None:
                shared = member_info
        return shared

    if isinstance(proper, LiteralType):
        proper = proper.fallback
    if not isinstance(proper, Instance):
        return None

    enum_info = proper.type
    seen_enum_root = False
    seen_value_mixin = False
    for ancestor in enum_info.mro:
        name = ancestor.fullname
        if name in _FLAG_BASES:
            # Flag-style enums are always exempt, allow-list or not.
            return None
        if name == _PLAIN_ENUM_BASE:
            seen_enum_root = True
        elif _is_value_mixin(ancestor):
            seen_value_mixin = True

    if not seen_enum_root:
        return None
    if not seen_value_mixin or enum_info.fullname in _FRAMEWORK_GUARANTEED_ENUMS:
        return enum_info
    # Value-based enum without explicit trust: `is` may diverge from `==`
    # when callers pass the underlying primitive value.
    return None
def _emit(ctx: MethodContext, op: str, enum_cls: TypeInfo) -> Type:
    """Report the diagnostic for ``op`` and pass the return type through.

    ``op`` is the comparison operator as written; ``"=="`` suggests ``is``
    and anything else suggests ``is not``. The failure is reported at
    ``ctx.context`` under the ``ENUM_IDENTITY`` error code, and the hook's
    default return type is returned unchanged.
    """
    if op == "==":
        suggested = "is"
    else:
        suggested = "is not"
    message = (
        f"Use `{suggested}` instead of `{op}` to compare "
        f"`{enum_cls.name}` enum instances"
    )
    ctx.api.fail(message, ctx.context, code=ENUM_IDENTITY)
    return ctx.default_return_type
def _make_hook(op: str) -> Callable[[MethodContext], Type]:
    """Build the method-hook callback for one comparison operator.

    ``op`` is the operator text (``"=="`` for ``__eq__``, ``"!="`` for
    ``__ne__``) and is only used to word the diagnostic.
    """

    def hook(ctx: MethodContext) -> Type:
        """Flag the call when both operands resolve to the same enum class."""
        lhs = _enum_class(ctx.type)
        if lhs is None:
            return ctx.default_return_type

        # The right operand is the first actual of the first formal; it may
        # be missing, in which case there is nothing to compare against.
        rhs_type = None
        if ctx.arg_types and ctx.arg_types[0]:
            rhs_type = ctx.arg_types[0][0]
        rhs = _enum_class(rhs_type)

        if rhs is not None and lhs.fullname == rhs.fullname:
            return _emit(ctx, op, lhs)
        return ctx.default_return_type

    return hook
# Build both hooks once at import time; ``get_method_hook`` hands out these
# shared callbacks instead of creating a new closure per lookup.
_EQ_HOOK, _NE_HOOK = _make_hook("=="), _make_hook("!=")
class HassEnumIdentityPlugin(Plugin):
    """Mypy plugin that hooks ``__eq__``/``__ne__`` method calls."""

    def get_method_hook(
        self, fullname: str
    ) -> Callable[[MethodContext], Type] | None:
        """Pick the hook matching the comparison method named by ``fullname``.

        ``a == b`` desugars to ``a.__eq__(b)`` and ``a != b`` to
        ``a.__ne__(b)``. Mypy passes the method's full name, and its suffix
        tells which operator triggered the call. Any other method gets no
        hook (``None``).
        """
        hooks_by_suffix = (
            (".__eq__", _EQ_HOOK),
            (".__ne__", _NE_HOOK),
        )
        for suffix, hook in hooks_by_suffix:
            if fullname.endswith(suffix):
                return hook
        return None
def plugin(version: str) -> type[Plugin]:
    """Return the plugin class for mypy to instantiate (module entry point)."""
    # ``version`` is part of the entry-point signature but is not consulted:
    # the same plugin class is returned unconditionally.
    return HassEnumIdentityPlugin