mirror of
https://github.com/home-assistant/core.git
synced 2026-07-01 03:36:05 +01:00
1540 lines
58 KiB
Python
1540 lines
58 KiB
Python
"""Support for MQTT message handling."""
|
|
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
|
|
import contextlib
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache, partial
|
|
from itertools import chain, groupby
|
|
import logging
|
|
from operator import attrgetter
|
|
import queue
|
|
import socket
|
|
import ssl
|
|
import time
|
|
from typing import TYPE_CHECKING, Any
|
|
from uuid import uuid4
|
|
|
|
import certifi
|
|
import paho.mqtt.client as mqtt
|
|
from paho.mqtt.matcher import MQTTMatcher
|
|
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import (
|
|
CONF_CLIENT_ID,
|
|
CONF_PASSWORD,
|
|
CONF_PORT,
|
|
CONF_PROTOCOL,
|
|
CONF_USERNAME,
|
|
EVENT_HOMEASSISTANT_STOP,
|
|
)
|
|
from homeassistant.core import (
|
|
CALLBACK_TYPE,
|
|
Event,
|
|
HassJob,
|
|
HassJobType,
|
|
HomeAssistant,
|
|
callback,
|
|
get_hassjob_callable_job_type,
|
|
)
|
|
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
|
from homeassistant.helpers.dispatcher import (
|
|
async_dispatcher_connect,
|
|
async_dispatcher_send,
|
|
)
|
|
from homeassistant.helpers.frame import ReportBehavior, report_usage
|
|
from homeassistant.helpers.importlib import async_import_module
|
|
from homeassistant.helpers.start import async_at_started
|
|
from homeassistant.helpers.typing import ConfigType
|
|
from homeassistant.setup import SetupPhases, async_pause_setup
|
|
from homeassistant.util.collection import chunked_or_all
|
|
from homeassistant.util.logging import catch_log_exception, log_exception
|
|
|
|
from .async_client import AsyncMQTTClient
|
|
from .const import (
|
|
CONF_BIRTH_MESSAGE,
|
|
CONF_BROKER,
|
|
CONF_CERTIFICATE,
|
|
CONF_CLIENT_CERT,
|
|
CONF_CLIENT_KEY,
|
|
CONF_KEEPALIVE,
|
|
CONF_TLS_INSECURE,
|
|
CONF_TRANSPORT,
|
|
CONF_WILL_MESSAGE,
|
|
CONF_WS_HEADERS,
|
|
CONF_WS_PATH,
|
|
DEFAULT_BIRTH,
|
|
DEFAULT_ENCODING,
|
|
DEFAULT_KEEPALIVE,
|
|
DEFAULT_PORT,
|
|
DEFAULT_QOS,
|
|
DEFAULT_TRANSPORT,
|
|
DEFAULT_WILL,
|
|
DEFAULT_WS_HEADERS,
|
|
DEFAULT_WS_PATH,
|
|
DOMAIN,
|
|
MQTT_CONNECTION_STATE,
|
|
MQTT_PROCESSED_SUBSCRIPTIONS,
|
|
PROTOCOL_5,
|
|
PROTOCOL_31,
|
|
PROTOCOL_311,
|
|
TRANSPORT_WEBSOCKETS,
|
|
)
|
|
from .models import (
|
|
DATA_MQTT,
|
|
MessageCallbackType,
|
|
MqttData,
|
|
PublishMessage,
|
|
PublishPayloadType,
|
|
ReceiveMessage,
|
|
)
|
|
from .util import EnsureJobAfterCooldown, get_file_path, mqtt_config_entry_enabled
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
MQTT_TIMEOUT = 5
|
|
|
|
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use if preferred size fails
|
|
PREFERRED_BUFFER_SIZE = 8 * 1024 * 1024 # Set receive buffer size to 8MiB
|
|
|
|
DISCOVERY_COOLDOWN = 5
|
|
# The initial subscribe cooldown controls how long to wait to group
|
|
# subscriptions together. This is to avoid making too many subscribe
|
|
# requests in a short period of time. If the number is too low, the
|
|
# system will be flooded with subscribe requests. If the number is too
|
|
# high, we risk being flooded with responses to the subscribe requests
|
|
# which can exceed the receive buffer size of the socket. To mitigate
|
|
# this, we increase the receive buffer size of the socket as well.
|
|
INITIAL_SUBSCRIBE_COOLDOWN = 0.5
|
|
SUBSCRIBE_COOLDOWN = 0.1
|
|
UNSUBSCRIBE_COOLDOWN = 0.1
|
|
TIMEOUT_ACK = 10
|
|
SUBSCRIBE_TIMEOUT = 10
|
|
RECONNECT_INTERVAL_SECONDS = 10
|
|
|
|
MAX_SUBSCRIBES_PER_CALL = 500
|
|
MAX_UNSUBSCRIBES_PER_CALL = 500
|
|
|
|
MAX_PACKETS_TO_READ = 500
|
|
|
|
type SocketType = socket.socket | ssl.SSLSocket | mqtt._WebsocketWrapper | Any # noqa: SLF001
|
|
|
|
type SubscribePayloadType = str | bytes | bytearray # Only bytes if encoding is None
|
|
|
|
|
|
def publish(
    hass: HomeAssistant,
    topic: str,
    payload: PublishPayloadType,
    qos: int = 0,
    retain: bool = False,
    encoding: str | None = DEFAULT_ENCODING,
    *,
    message_expiry_interval: int | None = None,
) -> None:
    """Schedule publishing a message to an MQTT topic.

    All arguments are forwarded unchanged to async_publish. The resulting
    coroutine is handed to hass.create_task and this function returns
    without waiting for the publish to finish.
    """
    publish_coro = async_publish(
        hass,
        topic,
        payload,
        qos,
        retain,
        encoding,
        message_expiry_interval=message_expiry_interval,
    )
    hass.create_task(publish_coro)
|
|
|
|
|
|
async def async_publish(
    hass: HomeAssistant,
    topic: str,
    payload: PublishPayloadType,
    qos: int = 0,
    retain: bool = False,
    encoding: str | None = DEFAULT_ENCODING,
    *,
    message_expiry_interval: int | None = None,
) -> None:
    """Publish a message to an MQTT topic.

    A payload that is neither bytes nor None is converted with str() and,
    when an encoding other than DEFAULT_ENCODING is requested, encoded to
    bytes with that encoding. If no encoding is given for such a payload,
    or the encoding fails, an error is logged and nothing is published.

    Raises HomeAssistantError when the MQTT config entry is not enabled.
    """
    if not mqtt_config_entry_enabled(hass):
        # pylint: disable-next=home-assistant-exception-message-with-translation
        raise HomeAssistantError(
            translation_key="mqtt_not_setup_cannot_publish",
            translation_domain=DOMAIN,
            translation_placeholders={"topic": topic},
        )
    mqtt_data = hass.data[DATA_MQTT]
    outgoing_payload = payload
    needs_conversion = payload is not None and not isinstance(payload, bytes)
    if needs_conversion:
        if not encoding:
            # Without an encoding only bytes can be passed through as-is
            _LOGGER.error(
                (
                    "Can't pass-through payload for publishing %s on %s with no"
                    " encoding set, need 'bytes' got %s"
                ),
                payload,
                topic,
                type(payload),
            )
            return
        outgoing_payload = str(payload)
        if encoding != DEFAULT_ENCODING:
            # A str payload is only left as str for the default encoding;
            # any other encoding requires the payload to be sent as bytes
            try:
                outgoing_payload = outgoing_payload.encode(encoding)
            except (AttributeError, LookupError, UnicodeEncodeError):
                _LOGGER.error(
                    "Can't encode payload for publishing %s on %s with encoding %s",
                    payload,
                    topic,
                    encoding,
                )
                return

    # Passing None for the qos or retain arguments is deprecated.
    # Custom integrations should update their code.
    # The fallback for `None` values can be removed with HA Core 2027.6
    if qos is None or retain is None:
        report_usage(  # type: ignore[unreachable]
            "that calls the MQTT publish API with `None` for qos or retain. "
            "The `qos` argument must be an `int`, "
            "and the `retain` argument must be a `bool`",
            breaks_in_ha_version="2027.6.0",
            core_behavior=ReportBehavior.LOG,
            exclude_integrations={DOMAIN},
        )
        qos = qos or 0
        retain = retain or False

    await mqtt_data.client.async_publish(
        topic,
        outgoing_payload,
        qos,
        retain,
        message_expiry_interval=message_expiry_interval,
    )
|
|
|
|
|
|
@callback
def async_on_subscribe_done(
    hass: HomeAssistant,
    topic: str,
    qos: int,
    on_subscribe_status: CALLBACK_TYPE,
) -> CALLBACK_TYPE:
    """Schedule on_subscribe_status when the matching subscription is processed.

    If the client is connected and the topic already has an active
    subscription that is not pending, on_subscribe_status is scheduled
    right away. A listener on MQTT_PROCESSED_SUBSCRIPTIONS also schedules
    it whenever a dispatched list contains the (topic, qos) pair.

    Call the returned callback to stop and clean up the status monitoring.
    """

    async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
        # Only react to the subscription this monitor was set up for
        if (topic, qos) in subscriptions:
            hass.loop.call_soon(on_subscribe_status)

    client = hass.data[DATA_MQTT].client
    already_subscribed = (
        client.connected
        and client.is_active_subscription(topic)
        and not client.is_pending_subscription(topic)
    )
    if already_subscribed:
        hass.loop.call_soon(on_subscribe_status)

    return async_dispatcher_connect(
        hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
    )
|
|
|
|
|
|
async def async_subscribe(
    hass: HomeAssistant,
    topic: str,
    msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
    qos: int = DEFAULT_QOS,
    encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
    """Subscribe to an MQTT topic.

    Delegates to async_subscribe_internal with the same arguments.

    Call the return value to unsubscribe.
    """
    unsubscribe = async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
    return unsubscribe
|
|
|
|
|
|
@callback
def async_subscribe_internal(
    hass: HomeAssistant,
    topic: str,
    msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
    qos: int = DEFAULT_QOS,
    encoding: str | None = DEFAULT_ENCODING,
    job_type: HassJobType | None = None,
) -> CALLBACK_TYPE:
    """Subscribe to an MQTT topic.

    This function is internal to the MQTT integration and may change at
    any time. It should not be considered a stable API.

    Raises HomeAssistantError when no MQTT data is stored in hass.data or
    when the MQTT config entry is not enabled.

    Call the return value to unsubscribe.
    """
    try:
        mqtt_data = hass.data[DATA_MQTT]
    except KeyError as exc:
        # pylint: disable-next=home-assistant-exception-message-with-translation
        raise HomeAssistantError(
            translation_key="mqtt_not_setup_cannot_subscribe",
            translation_domain=DOMAIN,
            translation_placeholders={"topic": topic},
        ) from exc

    client = mqtt_data.client
    if mqtt_config_entry_enabled(hass):
        return client.async_subscribe(topic, msg_callback, qos, encoding, job_type)

    # pylint: disable-next=home-assistant-exception-message-with-translation
    raise HomeAssistantError(
        translation_key="mqtt_not_enabled_cannot_subscribe",
        translation_domain=DOMAIN,
        translation_placeholders={"topic": topic},
    )
|
|
|
|
|
|
def subscribe(
    hass: HomeAssistant,
    topic: str,
    msg_callback: MessageCallbackType,
    qos: int = DEFAULT_QOS,
    encoding: str = "utf-8",
) -> Callable[[], None]:
    """Subscribe to an MQTT topic.

    Runs async_subscribe on hass.loop and blocks until it has completed.
    Call the return value to unsubscribe.
    """
    subscribe_future = asyncio.run_coroutine_threadsafe(
        async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
    )
    async_remove = subscribe_future.result()

    def remove() -> None:
        """Remove listener convert."""
        # The unsubscribe callable is handed to the loop directly rather
        # than through hass.add_job: MQTT traffic tends to be high volume
        # and hass.add_job spends most of its time inspecting the callable
        # to work out how to run it.
        hass.loop.call_soon_threadsafe(async_remove)

    return remove
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
class Subscription:
    """Immutable record describing one tracked subscription."""

    # Topic the subscription was requested for
    topic: str
    # Selects how the client tracks the subscription: per topic when True,
    # in the ordered wildcard collection when False
    is_simple_match: bool
    # Optional predicate taking a topic string
    # NOTE(review): presumably only set for wildcard subscriptions - confirm
    complex_matcher: Callable[[str], bool] | None
    # Job wrapping the message callback
    job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None]
    # Requested QoS level
    qos: int
    # Payload encoding, or None
    encoding: str | None
    # Identifier assigned to the subscription
    subscription_id: int
|
|
|
|
|
|
class MqttClientSetup:
    """Build a configured paho MQTT client from a config mapping."""

    _client: AsyncMQTTClient

    def __init__(self, config: ConfigType) -> None:
        """Store the config used to build the client.

        Only the config is stored here; self.setup creates the client and
        must be run in an executor job.
        """
        self._config = config

    def setup(self) -> None:
        """Create and configure the MQTT client.

        Run this in an executor job: resolving the certificate and key
        files accesses the file system, so it does IO.
        """
        config = self._config
        clean_session: bool | None = None
        # A config entry without a protocol setting is assumed to have been
        # migrated from YAML, where the legacy version 3.1.1 was the default.
        protocol = config.get(CONF_PROTOCOL, PROTOCOL_311)
        if protocol == PROTOCOL_31:
            proto = mqtt.MQTTv31
            clean_session = True
        elif protocol == PROTOCOL_5:
            # clean_session stays None for MQTT v5
            proto = mqtt.MQTTv5
        else:
            proto = mqtt.MQTTv311
            clean_session = True

        client_id = config.get(CONF_CLIENT_ID)
        if client_id is None:
            # PAHO MQTT relies on the MQTT server to generate a random
            # client ID for protocol version 3.1. That feature is not
            # mandatory for servers, so we generate our own.
            client_id = mqtt._base62(uuid4().int, padding=22)  # noqa: SLF001
        transport: str = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
        client = AsyncMQTTClient(
            callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
            client_id=client_id,
            # See: https://eclipse.dev/paho/files/paho.mqtt.python/html/client.html
            # clean_session (bool, defaults to None) determines the client
            # type. If True, the broker removes all information about this
            # client when it disconnects. If False, the client is persistent
            # and subscription information and queued messages are retained
            # when the client disconnects. A client never discards its own
            # outgoing messages on disconnect; connect() or reconnect()
            # causes them to be resent, and reinitialise() resets a client
            # to its original state. The argument only applies to MQTT
            # v3.1.1 and v3.1. It is not accepted for MQTT v5.0, where the
            # clean_start argument of connect() is used instead.
            clean_session=clean_session,
            protocol=proto,
            transport=transport,  # type: ignore[arg-type]
            reconnect_on_failure=False,
        )
        self._client = client
        client.setup()

        # Enable logging
        client.enable_logger()

        username: str | None = config.get(CONF_USERNAME)
        password: str | None = config.get(CONF_PASSWORD)
        if username is not None:
            client.username_pw_set(username, password)

        certificate = get_file_path(CONF_CERTIFICATE, config.get(CONF_CERTIFICATE))
        if certificate == "auto":
            # "auto" selects the CA bundle shipped with certifi
            certificate = certifi.where()

        client_key = get_file_path(CONF_CLIENT_KEY, config.get(CONF_CLIENT_KEY))
        client_cert = get_file_path(CONF_CLIENT_CERT, config.get(CONF_CLIENT_CERT))
        tls_insecure = config.get(CONF_TLS_INSECURE)
        if transport == TRANSPORT_WEBSOCKETS:
            ws_path: str = config.get(CONF_WS_PATH, DEFAULT_WS_PATH)
            ws_headers: dict[str, str] = config.get(CONF_WS_HEADERS, DEFAULT_WS_HEADERS)
            client.ws_set_options(ws_path, ws_headers)
        if certificate is not None:
            client.tls_set(
                certificate,
                certfile=client_cert,
                keyfile=client_key,
                tls_version=ssl.PROTOCOL_TLS_CLIENT,
            )

            # NOTE(review): the insecure flag is applied only after TLS has
            # been configured - confirm this nesting against upstream
            if tls_insecure is not None:
                client.tls_insecure_set(tls_insecure)

    @property
    def client(self) -> AsyncMQTTClient:
        """Return the paho MQTT client."""
        return self._client
|
|
|
|
|
|
def try_connection(
    user_input: dict[str, Any],
) -> bool:
    """Test if we can connect to an MQTT broker.

    A client is built from user_input and connected to the broker and port
    found under CONF_BROKER and CONF_PORT. Returns whether the connect
    callback reported success, or False if no result arrived within
    MQTT_TIMEOUT seconds. The client is always disconnected and its loop
    stopped before returning.
    """
    setup_helper = MqttClientSetup(user_input)
    setup_helper.setup()
    client = setup_helper.client

    # Hands the connect result from the client's loop over to this thread
    result: queue.Queue[bool] = queue.Queue(maxsize=1)

    def on_connect(
        _mqttc: mqtt.Client,
        _userdata: None,
        _connect_flags: mqtt.ConnectFlags,
        reason_code: mqtt.ReasonCode,
        _properties: mqtt.Properties | None = None,
    ) -> None:
        """Handle connection result."""
        result.put(not reason_code.is_failure)

    client.on_connect = on_connect

    client.connect_async(user_input[CONF_BROKER], user_input[CONF_PORT])
    client.loop_start()

    try:
        connected = result.get(timeout=MQTT_TIMEOUT)
    except queue.Empty:
        connected = False
    finally:
        client.disconnect()
        client.loop_stop()
    return connected
|
|
|
|
|
|
class MQTT:
|
|
"""Home Assistant MQTT client."""
|
|
|
|
_mqttc: AsyncMQTTClient
|
|
_last_subscribe: float
|
|
_mqtt_data: MqttData
|
|
_supports_subscription_identifiers: bool = False
|
|
|
|
def __init__(
    self, hass: HomeAssistant, config_entry: ConfigEntry, conf: ConfigType
) -> None:
    """Set up the state tracked by the Home Assistant MQTT client."""
    self.hass = hass
    self.loop = hass.loop
    self.config_entry = config_entry
    self.conf = conf
    # A config entry without a protocol setting is assumed to have been
    # migrated from YAML, where the legacy version 3.1.1 was the default.
    configured_protocol = conf.get(CONF_PROTOCOL, PROTOCOL_311)
    self.is_mqttv5 = configured_protocol == PROTOCOL_5

    # Simple subscriptions are grouped per topic
    self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
        set
    )
    # A dict with `None` values is used instead of a set so the order of
    # the wildcard subscriptions is preserved.
    self._wildcard_subscriptions: dict[Subscription, None] = {}
    # Tracks, per Subscription, the topics for which a retained message
    # was already delivered, so a retained message is received at most
    # once per topic. This avoids flooding already active subscribers
    # when a new subscriber subscribes to a topic.
    self._retained_topics: defaultdict[Subscription, set[str]] = defaultdict(set)
    self.connected = False
    self._ha_started = asyncio.Event()
    self._cleanup_on_unload: list[Callable[[], None]] = []

    self._connection_lock = asyncio.Lock()
    self._pending_operations: dict[int, asyncio.Future[None]] = {}
    self._subscribe_debouncer = EnsureJobAfterCooldown(
        INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
    )
    self._misc_timer: asyncio.TimerHandle | None = None
    self._reconnect_task: asyncio.Task | None = None
    self._should_reconnect: bool = True
    self._available_future: asyncio.Future[bool] | None = None

    self._max_qos: defaultdict[str, int] = defaultdict(int)  # topic, max qos
    self._pending_subscriptions: dict[str, int] = {}  # topic, qos
    self._unsubscribe_debouncer = EnsureJobAfterCooldown(
        UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
    )
    self._pending_unsubscribes: set[str] = set()  # topic
    # Both registrations return a callable that is run by cleanup()
    unload_callbacks = (
        async_at_started(hass, self._async_ha_started),
        hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, self._async_ha_stop),
    )
    self._cleanup_on_unload.extend(unload_callbacks)
    self._socket_buffersize: int | None = None
|
|
|
|
@callback
def _async_ha_started(self, _hass: HomeAssistant) -> None:
    """Set the started event once Home Assistant has started."""
    started_event = self._ha_started
    started_event.set()
|
|
|
|
async def _async_ha_stop(self, _event: Event) -> None:
    """Disconnect the client when the Home Assistant stop event fires."""
    disconnect = self.async_disconnect
    await disconnect()
|
|
|
|
async def async_start(
    self,
    mqtt_data: MqttData,
) -> None:
    """Start Home Assistant MQTT client.

    Stores mqtt_data on the instance and initializes the paho client.
    """
    self._mqtt_data = mqtt_data
    init_client = self.async_init_client
    await init_client()
|
|
|
|
@property
|
|
def subscriptions(self) -> set[Subscription]:
|
|
"""Return the tracked subscriptions."""
|
|
return {
|
|
*chain.from_iterable(self._simple_subscriptions.values()),
|
|
*self._wildcard_subscriptions,
|
|
}
|
|
|
|
def cleanup(self) -> None:
|
|
"""Clean up listeners."""
|
|
while self._cleanup_on_unload:
|
|
self._cleanup_on_unload.pop()()
|
|
|
|
@contextlib.asynccontextmanager
async def _async_connect_in_executor(self) -> AsyncGenerator[None]:
    """Swap the socket callbacks for the duration of an executor connect.

    Inside the context on_socket_open and on_socket_register_write point
    to the non-async variants; on exit they are switched back to the
    variants that run in the event loop.
    """
    client = self._mqttc
    try:
        # While connecting in the executor these callbacks are invoked
        # from the executor as well.
        client.on_socket_open = self._on_socket_open
        client.on_socket_register_write = self._on_socket_register_write
        yield
    finally:
        # The executor job is done, so the callbacks can be handled in
        # the event loop again.
        client.on_socket_open = self._async_on_socket_open
        client.on_socket_register_write = self._async_on_socket_register_write
|
|
|
|
async def async_init_client(self) -> None:
    """Initialize paho client.

    Imports the async client module, builds the client in an executor
    job, attaches the callbacks, applies the will message (if any) and
    stores the client on the instance.
    """
    with async_pause_setup(self.hass, SetupPhases.WAIT_IMPORT_PACKAGES):
        await async_import_module(
            self.hass, "homeassistant.components.mqtt.async_client"
        )

    client_setup = MqttClientSetup(self.conf)
    await self.hass.async_add_executor_job(client_setup.setup)
    mqttc = client_setup.client
    # The socket close and unregister-write callbacks are only ever
    # called in the event loop
    mqttc.on_socket_close = self._async_on_socket_close
    mqttc.on_socket_unregister_write = self._async_on_socket_unregister_write

    # These will be called in the event loop
    mqttc.on_connect = self._async_mqtt_on_connect
    mqttc.on_disconnect = self._async_mqtt_on_disconnect
    mqttc.on_message = self._async_mqtt_on_message
    mqttc.on_publish = self._async_mqtt_on_publish
    # Subscribe and unsubscribe acknowledgements share one handler
    mqttc.on_subscribe = self._async_mqtt_on_subscribe_unsubscribe
    mqttc.on_unsubscribe = self._async_mqtt_on_subscribe_unsubscribe

    # suppress exceptions at callback
    mqttc.suppress_exceptions = True

    will = self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL)
    if will:
        will_message = PublishMessage(**will)
        mqttc.will_set(
            topic=will_message.topic,
            payload=will_message.payload,
            qos=will_message.qos,
            retain=will_message.retain,
        )

    self._mqttc = mqttc
|
|
|
|
@callback
def _async_reader_callback(self, client: mqtt.Client) -> None:
    """Read pending data from the socket and report a non-zero status."""
    status = client.loop_read(MAX_PACKETS_TO_READ)
    if status != 0:
        self._async_handle_callback_exception(status)
|
|
|
|
@callback
def _async_start_misc_periodic(self) -> None:
    """Start calling the client's misc loop once per second."""
    assert self._misc_timer is None, "Misc periodic already started"
    _LOGGER.debug("%s: Starting client misc loop", self.config_entry.title)

    def _schedule_next_run() -> None:
        """Arm the timer for the next misc run, one second from now."""
        self._misc_timer = self.loop.call_at(self.loop.time() + 1, _async_misc)

    @callback
    def _async_misc() -> None:
        """Run the MQTT client misc loop and re-arm the timer on success."""
        if self._mqttc.loop_misc() != mqtt.MQTT_ERR_SUCCESS:
            # The periodic run is not rescheduled after a failure
            return
        _schedule_next_run()

    _schedule_next_run()
|
|
|
|
def _increase_socket_buffer_size(self, sock: SocketType) -> None:
    """Increase the receive buffer size of the socket.

    Starts with PREFERRED_BUFFER_SIZE and halves the size after every
    OSError. A warning is logged when a size at or below MIN_BUFFER_SIZE
    fails as well.
    """
    if not hasattr(sock, "setsockopt") and hasattr(sock, "_socket"):
        # The WebsocketWrapper does not wrap setsockopt, so the underlying
        # socket is used instead. Remove this once
        # https://github.com/eclipse/paho.mqtt.python/pull/843
        # is available.
        sock = sock._socket  # noqa: SLF001

    buffer_size = PREFERRED_BUFFER_SIZE
    while True:
        try:
            # Some operating systems do not allow us to set the preferred
            # buffer size. In that case we try some other size options.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, buffer_size)  # type: ignore[union-attr]
        except OSError as err:
            if buffer_size > MIN_BUFFER_SIZE:
                buffer_size //= 2
                continue
            _LOGGER.warning(
                "Unable to increase the socket buffer size to %s; "
                "The connection may be unstable if the MQTT broker "
                "sends data at volume or a large amount of subscriptions "
                "need to be processed: %s",
                buffer_size,
                err,
            )
        # Reached after a successful call or after the final failure
        return
|
|
|
|
def _on_socket_open(
    self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
    """Hand the socket open event over to the event loop thread-safely."""
    schedule = self.loop.call_soon_threadsafe
    schedule(self._async_on_socket_open, client, userdata, sock)
|
|
|
|
@callback
def _async_on_socket_open(
    self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
    """Handle socket open.

    Registers the socket as a reader with the event loop, makes sure the
    misc periodic is running and reads from the socket once.
    """
    fileno = sock.fileno()
    _LOGGER.debug("%s: connection opened %s", self.config_entry.title, fileno)
    has_valid_fileno = fileno > -1
    if has_valid_fileno:
        self._increase_socket_buffer_size(sock)
        reader = partial(self._async_reader_callback, client)
        self.loop.add_reader(sock, reader)
    if not self._misc_timer:
        self._async_start_misc_periodic()
    # add_reader only takes effect on the next loop iteration, so consume
    # the buffer right away to keep it from filling up
    self._async_reader_callback(client)
|
|
|
|
@callback
def _async_on_socket_close(
    self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
    """Handle socket close.

    Reports a failed connection result, removes the socket reader and
    cancels the misc periodic timer.
    """
    fileno = sock.fileno()
    _LOGGER.debug("%s: connection closed %s", self.config_entry.title, fileno)
    # The socket may be closed before the connect result was set; make
    # sure the first connection result is set in that case
    self._async_connection_result(False)
    if fileno > -1:
        self.loop.remove_reader(sock)
    misc_timer = self._misc_timer
    if misc_timer:
        misc_timer.cancel()
        self._misc_timer = None
|
|
|
|
@callback
def _async_writer_callback(self, client: mqtt.Client) -> None:
    """Write pending data to the socket and report a non-zero status."""
    status = client.loop_write()
    if status != 0:
        self._async_handle_callback_exception(status)
|
|
|
|
def _on_socket_register_write(
    self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
    """Hand the register-write event over to the event loop thread-safely."""
    schedule = self.loop.call_soon_threadsafe
    # None is forwarded in place of userdata
    schedule(self._async_on_socket_register_write, client, None, sock)
|
|
|
|
@callback
def _async_on_socket_register_write(
    self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
    """Register the socket as a writer with the event loop."""
    fileno = sock.fileno()
    _LOGGER.debug("%s: register write %s", self.config_entry.title, fileno)
    if fileno <= -1:
        # No valid file descriptor, nothing to register
        return
    writer = partial(self._async_writer_callback, client)
    self.loop.add_writer(sock, writer)
|
|
|
|
@callback
def _async_on_socket_unregister_write(
    self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
    """Remove the socket as a writer from the event loop."""
    fileno = sock.fileno()
    _LOGGER.debug("%s: unregister write %s", self.config_entry.title, fileno)
    if fileno <= -1:
        # No valid file descriptor, nothing to unregister
        return
    self.loop.remove_writer(sock)
|
|
|
|
def is_active_subscription(self, topic: str) -> bool:
|
|
"""Check if a topic has an active subscription."""
|
|
return topic in self._simple_subscriptions or any(
|
|
other.topic == topic for other in self._wildcard_subscriptions
|
|
)
|
|
|
|
def is_pending_subscription(self, topic: str) -> bool:
|
|
"""Check if a topic has a pending subscription."""
|
|
return topic in self._pending_subscriptions
|
|
|
|
async def async_publish(
    self,
    topic: str,
    payload: PublishPayloadType,
    qos: int,
    retain: bool,
    *,
    message_expiry_interval: int | None = None,
) -> None:
    """Publish a MQTT message.

    Raises ServiceValidationError when a message expiry interval is
    requested while the client is not using MQTT v5.
    """
    properties = mqtt.Properties(mqtt.PacketTypes.PUBLISH)  # type: ignore[no-untyped-call]
    expiry_requested = message_expiry_interval is not None
    if expiry_requested and not self.is_mqttv5:
        raise ServiceValidationError(
            translation_domain=DOMAIN,
            translation_key="mqtt_message_expiry_interval_not_supported",
            translation_placeholders={
                "topic": topic,
                "protocol": self.conf.get(CONF_PROTOCOL, PROTOCOL_311),
            },
        )
    if expiry_requested:
        properties.MessageExpiryInterval = message_expiry_interval
    msg_info = self._mqttc.publish(topic, payload, qos, retain, properties)
    retained_label = " retained" if retain else ""
    _LOGGER.debug(
        "Transmitting%s message on %s: '%s', mid: %s, qos: %s,"
        " message_expiry_interval: %s",
        retained_label,
        topic,
        payload,
        msg_info.mid,
        qos,
        message_expiry_interval,
    )
    await self._async_wait_for_mid_or_raise(msg_info.mid, msg_info.rc)
|
|
|
|
async def async_connect(self, client_available: asyncio.Future[bool]) -> None:
    """Connect to the host. Does not process messages yet.

    The blocking connect call of the client runs in an executor job while
    the connection lock is held. client_available is stored so it can be
    resolved with the connection result. An OSError or websocket
    connection error, as well as a non-zero result code of the connect
    call, is logged and reported as a failed connection.
    """

    result: int | None = None
    self._available_future = client_available
    self._should_reconnect = True
    connect_partial = partial(
        self._mqttc.connect,
        host=self.conf[CONF_BROKER],
        port=self.conf.get(CONF_PORT, DEFAULT_PORT),
        keepalive=self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
        # See:
        # https://eclipse.dev/paho/files/paho.mqtt.python/html/client.html
        # `clean_start` (bool) - (MQTT v5.0 only) `True`, `False` or
        # `MQTT_CLEAN_START_FIRST_ONLY`. Sets the MQTT v5.0 clean_start flag
        # always, never or on the first successful connect only,
        # respectively. MQTT session data (such as outstanding messages and
        # subscriptions) is cleared on successful connect when the
        # clean_start flag is set. For MQTT v3.1.1, the clean_session
        # argument of Client should be used for similar result.
        clean_start=True if self.is_mqttv5 else mqtt.MQTT_CLEAN_START_FIRST_ONLY,
    )
    try:
        async with self._connection_lock, self._async_connect_in_executor():
            result = await self.hass.async_add_executor_job(connect_partial)
    except (OSError, mqtt.WebsocketConnectionError) as err:
        _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
        self._async_connection_result(False)
    finally:
        # result stays None when the connect call raised; a non-zero
        # result code means the connect call itself reported an error
        if result is not None and result != 0:
            _LOGGER.error(
                "Failed to connect to MQTT server: %s",
                mqtt.error_string(result),
            )
            self._async_connection_result(False)
|
|
|
|
@callback
def _async_connection_result(self, connected: bool) -> None:
    """Handle a connection result.

    Resolves the pending availability future, if any. A successful result
    cancels the reconnect task; a failed one starts the reconnect loop
    when reconnecting is wanted and no reconnect task exists yet.
    """
    available_future = self._available_future
    if available_future and not available_future.done():
        available_future.set_result(connected)

    if connected:
        self._async_cancel_reconnect()
        return

    if self._should_reconnect and not self._reconnect_task:
        self._reconnect_task = self.config_entry.async_create_background_task(
            self.hass, self._reconnect_loop(), "mqtt reconnect loop"
        )
|
|
|
|
@callback
def _async_cancel_reconnect(self) -> None:
    """Cancel the reconnect task, if there is one, and forget it."""
    reconnect_task = self._reconnect_task
    if not reconnect_task:
        return
    reconnect_task.cancel()
    self._reconnect_task = None
|
|
|
|
async def _reconnect_loop(self) -> None:
    """Reconnect to the MQTT server.

    Loops forever: while not connected a reconnect is attempted in an
    executor job, followed by a sleep of RECONNECT_INTERVAL_SECONDS.
    """
    reconnect_errors = (OSError, mqtt.WebsocketConnectionError)

    while True:
        if not self.connected:
            try:
                async with self._connection_lock, self._async_connect_in_executor():
                    await self.hass.async_add_executor_job(self._mqttc.reconnect)
            except reconnect_errors as err:
                # A failed attempt is only logged; the next iteration retries
                _LOGGER.debug(
                    "Error re-connecting to MQTT server due to exception: %s", err
                )

        await asyncio.sleep(RECONNECT_INTERVAL_SECONDS)
|
|
|
|
async def async_disconnect(self, disconnect_paho_client: bool = False) -> None:
    """Stop the MQTT client.

    We only disconnect gracefully if disconnect_paho_client is set, but not
    when Home Assistant is shut down.
    """

    # Pending subscriptions are no longer waited for
    await self._subscribe_debouncer.async_cleanup()
    # Go back to the initial subscribe cooldown
    self._subscribe_debouncer.set_timeout(INITIAL_SUBSCRIBE_COOLDOWN)
    # Stop the unsubscribe debouncer ...
    await self._unsubscribe_debouncer.async_cleanup()
    # ... and process the queued unsubscribes now
    await self._async_perform_unsubscribes()

    # Wait for the outstanding ACKs to be processed
    pending_acks = self._pending_operations.values()
    if pending_acks:
        await asyncio.wait(pending_acks)

    # Stop the MQTT loop
    async with self._connection_lock:
        self._should_reconnect = False
        self._async_cancel_reconnect()
        # Unless the entry is reloaded we do not disconnect gracefully,
        # to ensure the broker publishes the will message
        if disconnect_paho_client:
            self._mqttc.disconnect()
|
|
|
|
@callback
def async_restore_tracked_subscriptions(
    self, subscriptions: set[Subscription]
) -> None:
    """Restore tracked subscriptions after reload.

    Every subscription is tracked again as a new Subscription carrying
    the same data and a fresh subscription id: 1 for simple matches, a
    generated id for the others. The matching cache is cleared afterwards.
    """
    for tracked in subscriptions:
        if tracked.is_simple_match:
            subscription_id = 1
        else:
            id_generator = self._mqtt_data.subscription_id_generator
            subscription_id = id_generator.get_subscription_id(tracked.topic)
        restored = Subscription(
            tracked.topic,
            tracked.is_simple_match,
            tracked.complex_matcher,
            tracked.job,
            tracked.qos,
            tracked.encoding,
            subscription_id,
        )
        self._async_track_subscription(restored)
    self._matching_subscriptions.cache_clear()
|
|
|
|
@callback
def _async_track_subscription(self, subscription: Subscription) -> None:
    """Track a subscription.

    Nothing is sent to the broker here; no SUBSCRIBE message is issued.

    Clearing the cache of _matching_subscriptions is left to the caller.
    """
    if not subscription.is_simple_match:
        # The dict is used as an ordered set, the value carries no meaning
        self._wildcard_subscriptions[subscription] = None
        return
    self._simple_subscriptions[subscription.topic].add(subscription)
|
|
|
|
@callback
def _async_untrack_subscription(self, subscription: Subscription) -> None:
    """Untrack a subscription.

    This method does not send an UNSUBSCRIBE message to the broker.

    The caller is responsible for clearing the cache of
    _matching_subscriptions.

    Raises HomeAssistantError when the subscription is not tracked, for
    example when it is removed a second time.
    """
    topic = subscription.topic
    try:
        if subscription.is_simple_match:
            simple_subscriptions = self._simple_subscriptions
            # Look the topic up with .get(): indexing the defaultdict would
            # insert an empty set for an unknown topic, and that stale key
            # would make is_active_subscription report the topic as active.
            topic_subscriptions = simple_subscriptions.get(topic)
            if topic_subscriptions is None:
                raise KeyError(topic)
            topic_subscriptions.remove(subscription)
            if not topic_subscriptions:
                # Drop the topic once its last subscription is gone
                del simple_subscriptions[topic]
        else:
            del self._wildcard_subscriptions[subscription]
    except (KeyError, ValueError) as exc:
        raise HomeAssistantError(
            translation_domain=DOMAIN,
            translation_key="mqtt_not_setup_cannot_unsubscribe_twice",
            translation_placeholders={"topic": topic},
        ) from exc
|
|
|
|
@callback
def _async_queue_subscriptions(
    self, subscriptions: Iterable[tuple[str, int]], queue_only: bool = False
) -> None:
    """Queue (topic, qos) pairs for subscribing.

    The highest qos seen per topic is remembered in ``_max_qos`` and used
    for the pending subscription. Unless ``queue_only`` is set, the
    subscribe debouncer is scheduled afterwards.
    """
    for topic, qos in subscriptions:
        max_qos = self._max_qos[topic]
        if max_qos < qos:
            max_qos = qos
            self._max_qos[topic] = max_qos
        self._pending_subscriptions[topic] = max_qos
        # A new subscribe supersedes an unsubscribe still waiting for the topic
        if topic in self._pending_unsubscribes:
            self._pending_unsubscribes.remove(topic)
    if not queue_only:
        self._subscribe_debouncer.async_schedule()
|
|
|
|
def _exception_message(
    self,
    msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
    msg: ReceiveMessage,
) -> str:
    """Return the log message for an exception raised by a message callback.

    For a ``partial`` the name of its first positional argument is
    reported. If the partial has no positional arguments, the wrapped
    function is reported instead.

    This method runs while another exception is being logged, so it must
    not raise itself: a callable without ``__name__`` is reported by its
    ``repr``.
    """
    if isinstance(msg_callback, partial):
        # Fall back to the wrapped function when there is no first argument
        target: Any = (
            msg_callback.args[0] if msg_callback.args else msg_callback.func
        )
    else:
        target = msg_callback
    call_back_name = getattr(target, "__name__", None)
    if call_back_name is None:
        # Callable objects and nested partials have no __name__
        call_back_name = repr(target)
    return (
        f"Exception in {call_back_name} when handling msg on "
        f"'{msg.topic}': '{msg.payload}'"  # type: ignore[str-bytes-safe]
    )
|
|
|
|
@callback
def async_subscribe(
    self,
    topic: str,
    msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
    qos: int,
    encoding: str | None = None,
    job_type: HassJobType | None = None,
) -> Callable[[], None]:
    """Subscribe to a topic with the requested qos.

    The subscription is tracked locally and, when the client is
    connected, queued for a SUBSCRIBE towards the broker. The returned
    callable removes the subscription again.

    Raises HomeAssistantError when ``topic`` is not a string.
    """
    if not isinstance(topic, str):
        raise HomeAssistantError(
            translation_domain=DOMAIN,
            translation_key="mqtt_topic_not_a_string",
            translation_placeholders={"topic": topic},
        )

    if job_type is None:
        job_type = get_hassjob_callable_job_type(msg_callback)
    if job_type is not HassJobType.Callback:
        # Simple callbacks have their exceptions caught inline when a
        # message is dispatched (for performance reasons), so only the
        # other job types are wrapped with catch_log_exception.
        msg_callback = catch_log_exception(
            msg_callback, partial(self._exception_message, msg_callback)
        )
    job = HassJob(msg_callback, job_type=job_type)

    has_wildcard = "+" in topic or "#" in topic
    if has_wildcard:
        matcher = _matcher_for_topic(topic)
        subscription_id = self._mqtt_data.subscription_id_generator.get_or_generate(
            topic
        )
    else:
        # Simple subscriptions need no matcher and always use identifier 1
        matcher = None
        subscription_id = 1

    subscription = Subscription(
        topic, not has_wildcard, matcher, job, qos, encoding, subscription_id
    )
    self._async_track_subscription(subscription)
    self._matching_subscriptions.cache_clear()

    # Only queue a SUBSCRIBE when currently connected.
    if self.connected:
        self._async_queue_subscriptions(((topic, qos),))

    return partial(self._async_remove, subscription)
|
|
|
|
@callback
def _async_remove(self, subscription: Subscription) -> None:
    """Remove a subscription and unsubscribe from its topic if connected."""
    self._async_untrack_subscription(subscription)
    self._matching_subscriptions.cache_clear()
    # Forget which retained messages this subscription already received
    self._retained_topics.pop(subscription, None)
    if not self.connected:
        # Nothing to unsubscribe from while disconnected
        return
    self._async_unsubscribe(subscription.topic, subscription.subscription_id)
|
|
|
|
@callback
def _async_unsubscribe(self, topic: str, subscription_id: int) -> None:
    """Queue an unsubscribe for a topic unless it is still in use.

    When other subscriptions on the topic remain, only the tracked
    maximum qos is recalculated and no unsubscribe is queued.
    """
    if self.is_active_subscription(topic):
        # Other subscriptions on topic remaining - don't unsubscribe.
        if self._max_qos[topic] != 0:
            remaining = self._matching_subscriptions(topic, (subscription_id,))
            self._max_qos[topic] = max(sub.qos for sub in remaining)
        return

    self._max_qos.pop(topic, None)
    # Make sure a subscribe that is still pending is not executed
    self._pending_subscriptions.pop(topic, None)

    self._pending_unsubscribes.add(topic)
    self._unsubscribe_debouncer.async_schedule()
|
|
|
|
async def _async_perform_subscriptions(self) -> None:
    """Send all pending subscriptions to the broker.

    Wildcard topics are subscribed one at a time. When subscription
    identifiers are supported, each carries the identifier looked up for
    its topic. The remaining topics are subscribed in chunks that share
    identifier 1. Every SUBSCRIBE is awaited before the next one is sent.
    """
    # Section 3.3.1.3 in the specification:
    # http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html
    # When sending a PUBLISH Packet to a Client the Server MUST
    # set the RETAIN flag to 1 if a message is sent as a result of a
    # new subscription being made by a Client [MQTT-3.3.1-8].
    # It MUST set the RETAIN flag to 0 when a PUBLISH Packet is sent to
    # a Client because it matches an established subscription regardless
    # of how the flag was set in the message it received [MQTT-3.3.1-9].
    #
    # Since we do not know if a published value is retained we need to
    # (re)subscribe, to ensure retained messages are replayed
    if not self._pending_subscriptions:
        return

    debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
    pending: dict[str, int] = self._pending_subscriptions

    # Move the wildcard topics out of the pending dict; what stays behind
    # is subscribed in bulk.
    wildcard_pending: dict[str, int] = {}
    for subscription in self._wildcard_subscriptions:
        wildcard_topic = subscription.topic
        if wildcard_topic in pending:
            wildcard_pending[wildcard_topic] = pending.pop(wildcard_topic)

    subscribe_chain = chunked_or_all(pending.items(), MAX_SUBSCRIBES_PER_CALL)

    bulk_properties = None
    if self._supports_subscription_identifiers and pending:
        bulk_properties = mqtt.Properties(packetType=mqtt.PacketTypes.SUBSCRIBE)  # type: ignore[no-untyped-call]
        bulk_properties.SubscriptionIdentifier = 1

    # Start a fresh dict for subscriptions queued while we are awaiting
    self._pending_subscriptions = {}

    for topic, qos in wildcard_pending.items():
        properties = None
        if self._supports_subscription_identifiers:
            properties = mqtt.Properties(packetType=mqtt.PacketTypes.SUBSCRIBE)  # type: ignore[no-untyped-call]
            properties.SubscriptionIdentifier = (
                self._mqtt_data.subscription_id_generator.get_subscription_id(topic)
            )

        result, mid = self._mqttc.subscribe(topic, qos, properties=properties)
        if debug_enabled:
            _LOGGER.debug(
                "Subscribing with mid: %s to topic %s "
                "with qos: %s and properties: %s",
                mid,
                topic,
                qos,
                properties,
            )
        self._last_subscribe = time.monotonic()

        await self._async_wait_for_mid_or_raise(mid, result)
        async_dispatcher_send(
            self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, [(topic, qos)]
        )

    for chunk in subscribe_chain:
        bulk_topics = list(chunk)
        if not bulk_topics:
            continue

        result, mid = self._mqttc.subscribe(bulk_topics, properties=bulk_properties)

        if debug_enabled:
            _LOGGER.debug(
                "Subscribing with mid: %s to topics with qos: %s", mid, bulk_topics
            )
        self._last_subscribe = time.monotonic()

        await self._async_wait_for_mid_or_raise(mid, result)
        async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, bulk_topics)
|
|
|
|
async def _async_perform_unsubscribes(self) -> None:
    """Send all pending unsubscribes to the broker.

    Topics are unsubscribed in chunks and every UNSUBSCRIBE is awaited.
    Afterwards the stored subscription identifier of each topic is
    released.
    """
    if not self._pending_unsubscribes:
        return

    # Take over the pending set and start a new one for later requests
    pending, self._pending_unsubscribes = self._pending_unsubscribes, set()
    topics = list(pending)
    debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

    for chunk in chunked_or_all(topics, MAX_UNSUBSCRIBES_PER_CALL):
        chunk_topics = list(chunk)

        result, mid = self._mqttc.unsubscribe(chunk_topics)
        if debug_enabled:
            _LOGGER.debug(
                "Unsubscribing with mid: %s to topics: %s", mid, chunk_topics
            )

        await self._async_wait_for_mid_or_raise(mid, result)

    # Remove stored subscription identifiers for the unsubscribed topics
    id_generator = self._mqtt_data.subscription_id_generator
    for topic in topics:
        id_generator.release(topic)
|
|
|
|
async def _async_resubscribe_and_publish_birth_message(
    self, birth_message: PublishMessage
) -> None:
    """Queue a resubscribe of all topics, then publish the birth message.

    The birth message is only published after Home Assistant has started,
    the discovery cooldown has passed and the subscribe debouncer has
    been flushed.
    """
    self._async_queue_resubscribe()
    self._subscribe_debouncer.async_schedule()

    # Wait for Home Assistant to start
    await self._ha_started.wait()
    # Wait for MQTT discovery to cool down
    await self._discovery_cooldown()

    # Flush the debouncer first, then switch to the shorter
    # subscribe cooldown period
    await self._subscribe_debouncer.async_execute()
    self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)

    await self.async_publish(
        topic=birth_message.topic,
        payload=birth_message.payload,
        qos=birth_message.qos,
        retain=birth_message.retain,
    )
    _LOGGER.info("MQTT client initialized, birth message sent")
|
|
|
|
@callback
def _async_mqtt_on_connect(
    self,
    _mqttc: mqtt.Client,
    _userdata: None,
    _connect_flags: mqtt.ConnectFlags,
    reason_code: mqtt.ReasonCode,
    properties: mqtt.Properties | None = None,
) -> None:
    """Handle the result of a connection attempt.

    On success all tracked topics are queued for resubscribing and, when
    a birth message is configured, it is published from a background
    task. On failure the error is logged and, for authentication related
    reason codes, a reauth flow is started.
    """
    if not self.is_mqttv5:
        self._supports_subscription_identifiers = False
    elif (
        properties is not None
        and hasattr(properties, "SubscriptionIdentifierAvailable")
        and properties.SubscriptionIdentifierAvailable == 0
    ):
        # The server explicitly disabled Subscription Identifiers
        _LOGGER.warning(
            "Your MQTT broker reports it does not support "
            "Subscription Identifiers, see "
            "https://docs.oasis-open.org/"
            "mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901092. "
            "Please use a supported MQTT broker; got broker properties: %s",
            properties,
        )
        self._supports_subscription_identifiers = False
    else:
        self._supports_subscription_identifiers = True

    if reason_code.is_failure:
        # Reason codes that require new credentials:
        # 24: Continue authentication
        # 25: Re-authenticate
        # 134: Bad user name or password
        # 135: Not authorized
        # 140: Bad authentication method
        if reason_code.value in (24, 25, 134, 135, 140):
            self._should_reconnect = False
            self.hass.async_create_task(self.async_disconnect())
            self.config_entry.async_start_reauth(self.hass)
        _LOGGER.error(
            "Unable to connect to the MQTT broker: %s",
            reason_code.getName(),  # type: ignore[no-untyped-call]
        )
        self._async_connection_result(False)
        return

    self.connected = True
    async_dispatcher_send(self.hass, MQTT_CONNECTION_STATE, True)
    _LOGGER.debug(
        "Connected to MQTT server %s:%s (%s)",
        self.conf[CONF_BROKER],
        self.conf.get(CONF_PORT, DEFAULT_PORT),
        reason_code,
    )

    birth: dict[str, Any]
    if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH):
        birth_message = PublishMessage(**birth)
        self.config_entry.async_create_background_task(
            self.hass,
            self._async_resubscribe_and_publish_birth_message(birth_message),
            name="mqtt re-subscribe and birth",
        )
    else:
        # No birth message: only queue the resubscribe and schedule it
        self._async_queue_resubscribe()
        self._subscribe_debouncer.async_schedule()

    self._async_connection_result(True)
|
|
|
|
@callback
def _async_queue_resubscribe(self) -> None:
    """Queue every tracked topic for subscribing again on reconnect.

    Nothing is sent here: self._async_perform_subscriptions must be
    called after this method to actually subscribe.
    """
    self._max_qos.clear()
    self._retained_topics.clear()
    # Group the subscriptions so each topic is re-subscribed only once,
    # using the highest qos that was requested for it.
    by_topic = attrgetter("topic")
    grouped = groupby(sorted(self.subscriptions, key=by_topic), by_topic)
    resubscribe = [
        (topic, max(subscription.qos for subscription in group))
        for topic, group in grouped
    ]
    self._async_queue_subscriptions(resubscribe, queue_only=True)
|
|
|
|
@lru_cache(None)  # pylint: disable=method-cache-max-size-none
def _matching_subscriptions(
    self, topic: str, identifiers: tuple[int, ...] | None
) -> list[Subscription]:
    """Return the tracked subscriptions that match a topic.

    When ``identifiers`` is given, only subscriptions whose subscription
    identifier is contained in it are returned. Results are cached, so
    the cache has to be cleared whenever the tracked subscriptions change.
    """

    def _wanted(subscription_id: int) -> bool:
        return identifiers is None or subscription_id in identifiers

    matches: list[Subscription] = []
    # The subscription identifier is always 1 for simple subscriptions,
    # so only include them when no identifiers are provided or 1 matches.
    if topic in self._simple_subscriptions and _wanted(1):
        matches.extend(self._simple_subscriptions[topic])
    for subscription in self._wildcard_subscriptions:
        # mypy doesn't know that complex_matcher is always set when
        # is_simple_match is False
        if subscription.complex_matcher(topic) and _wanted(  # type: ignore[misc]
            subscription.subscription_id
        ):
            matches.append(subscription)
    return matches
|
|
|
|
@callback
def _async_mqtt_on_message(
    self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None:
    """Dispatch a received MQTT message to all matching subscriptions.

    The payload is decoded per subscription with that subscription's
    encoding. Subscriptions for which decoding fails are skipped with a
    warning. A retained message is delivered only once per subscription
    and topic. Messages with an undecodable topic are dropped.
    """
    identifiers: tuple[int, ...] | None = None
    if self._supports_subscription_identifiers:
        # It is possible we have multiple messages if there
        # are overlapping wildcard subscriptions.
        # So we assigned all wildcard subscriptions with a
        # unique SubscriptionIdentifier. Simple subscriptions are assigned
        # with SubscriptionIdentifier 1.
        if TYPE_CHECKING:
            assert msg.properties is not None
            assert hasattr(msg.properties, "SubscriptionIdentifier")
        # Without identifiers on the message we fall back to matching
        # on the topic only (identifiers stays None).
        with contextlib.suppress(AttributeError):
            identifiers = tuple(msg.properties.SubscriptionIdentifier)
    try:
        # msg.topic is a property that decodes the topic to a string
        # every time it is accessed. Save the result to avoid
        # decoding the same topic multiple times.
        topic = msg.topic
    except UnicodeDecodeError:
        bare_topic: bytes = msg._topic  # noqa: SLF001
        _LOGGER.warning(
            "Skipping received%s message on invalid topic %s (qos=%s): %s",
            " retained" if msg.retain else "",
            bare_topic,
            msg.qos,
            msg.payload[0:8192],
        )
        return
    _LOGGER.debug(
        "Received%s message on %s (qos=%s) IDs=%s: %s",
        " retained" if msg.retain else "",
        topic,
        msg.qos,
        identifiers,
        msg.payload[0:8192],
    )
    # Cache of messages already built for this dispatch. The key includes
    # the encoding because the payload differs per encoding: keying on the
    # topic alone would hand a subscriber the payload decoded for another
    # subscriber of the same topic (str instead of bytes or vice versa).
    msg_cache: dict[tuple[str, str | None], ReceiveMessage] = {}

    for subscription in self._matching_subscriptions(topic, identifiers):
        if msg.retain:
            retained_topics = self._retained_topics[subscription]
            # Skip if the subscription already received a retained message
            if topic in retained_topics:
                continue
            # Remember the subscription had an initial retained message
            retained_topics.add(topic)

        encoding = subscription.encoding
        payload: SubscribePayloadType = msg.payload
        if encoding is not None:
            try:
                payload = msg.payload.decode(encoding)
            except (AttributeError, UnicodeDecodeError):
                _LOGGER.warning(
                    "Can't decode payload %s on %s with encoding %s (for %s)",
                    msg.payload[0:8192],
                    topic,
                    encoding,
                    subscription.job,
                )
                continue
        subscription_topic = subscription.topic
        cache_key = (subscription_topic, encoding)
        if cache_key not in msg_cache:
            # Only make one copy of the message per subscription topic
            # and encoding so we avoid storing a separate dataclass in
            # memory for each subscriber to the same topic for retained
            # messages
            receive_msg = ReceiveMessage(
                topic,
                payload,
                msg.qos,
                msg.retain,
                subscription_topic,
                msg.timestamp,
            )
            msg_cache[cache_key] = receive_msg
        else:
            receive_msg = msg_cache[cache_key]
        job = subscription.job
        if job.job_type is HassJobType.Callback:
            # We do not wrap Callback jobs in catch_log_exception since
            # its expensive and we have to do it 2x for every entity
            try:
                job.target(receive_msg)
            except Exception:  # noqa: BLE001
                log_exception(
                    partial(self._exception_message, job.target, receive_msg)
                )
        else:
            self.hass.async_run_hass_job(job, receive_msg)
    self._mqtt_data.state_write_requests.process_write_state_requests(msg)
|
|
|
|
@callback
def _async_mqtt_on_publish(
    self,
    _mqttc: mqtt.Client,
    _userdata: None,
    mid: int,
    _reason_code: mqtt.ReasonCode,
    _properties: mqtt.Properties | None,
) -> None:
    """Handle a publish callback by resolving the operation for ``mid``.

    Only ``mid`` is used; the other arguments are ignored.
    """
    self._async_mqtt_on_callback(mid)
|
|
|
|
@callback
def _async_mqtt_on_subscribe_unsubscribe(
    self,
    _mqttc: mqtt.Client,
    _userdata: None,
    mid: int,
    _reason_code: list[mqtt.ReasonCode],
    _properties: mqtt.Properties | None,
) -> None:
    """Handle a subscribe or unsubscribe callback for ``mid``.

    Only ``mid`` is used; the other arguments are ignored.
    """
    self._async_mqtt_on_callback(mid)
|
|
|
|
@callback
def _async_mqtt_on_callback(self, mid: int) -> None:
    """Resolve the future of a publish / subscribe / unsubscribe mid.

    A future that already timed out or was cancelled is left untouched.
    """
    ack_future = self._async_get_mid_future(mid)
    timed_out_or_cancelled = ack_future.done() and (
        ack_future.cancelled() or ack_future.exception()
    )
    if not timed_out_or_cancelled:
        ack_future.set_result(None)
|
|
|
|
@callback
def _async_get_mid_future(self, mid: int) -> asyncio.Future[None]:
    """Return the future tracking a mid, creating it when missing."""
    existing = self._pending_operations.get(mid)
    if existing:
        return existing
    created = self.hass.loop.create_future()
    self._pending_operations[mid] = created
    return created
|
|
|
|
@callback
def _async_handle_callback_exception(self, status: mqtt.MQTTErrorCode) -> None:
    """Log a warning with the textual form of an MQTT error code."""
    error_text = mqtt.error_string(status)
    _LOGGER.warning("Error returned from MQTT server: %s", error_text)
|
|
|
|
@callback
def _async_mqtt_on_disconnect(
    self,
    _mqttc: mqtt.Client,
    _userdata: None,
    _disconnect_flags: mqtt.DisconnectFlags,
    reason_code: mqtt.ReasonCode,
    properties: mqtt.Properties | None = None,
) -> None:
    """Handle a disconnect from the broker.

    Marks the client as disconnected, signals the connection state and
    logs the disconnect. Does nothing when already disconnected.
    """
    if not self.connected:
        # This function is re-entrant and may be called multiple times
        # when there is a broken pipe error.
        return
    # Make sure the first connection result is set in case the
    # disconnect arrives before the connect result was set
    self._async_connection_result(False)
    self.connected = False
    async_dispatcher_send(self.hass, MQTT_CONNECTION_STATE, False)
    # Reason code 0 is logged at INFO, everything else at DEBUG
    log_level = logging.INFO if reason_code == 0 else logging.DEBUG
    _LOGGER.log(
        log_level,
        "Disconnected from MQTT server %s:%s (%s)",
        self.conf[CONF_BROKER],
        self.conf.get(CONF_PORT, DEFAULT_PORT),
        reason_code,
    )
|
|
|
|
@callback
def _async_timeout_mid(self, future: asyncio.Future[None]) -> None:
    """Fail a mid future with a timeout unless it already completed."""
    if future.done():
        return
    future.set_exception(asyncio.TimeoutError)
|
|
|
|
async def _async_wait_for_mid_or_raise(
    self, mid: int | None, result_code: int
) -> None:
    """Wait until the operation with ``mid`` is acknowledged.

    Raises HomeAssistantError when ``result_code`` is non-zero. A missing
    acknowledgement within TIMEOUT_ACK seconds is only logged as a
    warning. The pending operation entry is removed in either case.
    """
    if result_code != 0:
        raise HomeAssistantError(
            translation_domain=DOMAIN,
            translation_key="mqtt_broker_error",
            translation_placeholders={
                "error_message": mqtt.error_string(result_code)
            },
        )

    if TYPE_CHECKING:
        assert mid is not None
    # The ACK callback may run before or after this method, so the future
    # is created here only if it does not exist yet.
    ack_future = self._async_get_mid_future(mid)
    timeout_handle = self.hass.loop.call_later(
        TIMEOUT_ACK, self._async_timeout_mid, ack_future
    )
    try:
        await ack_future
    except TimeoutError:
        _LOGGER.warning(
            "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid
        )
    finally:
        timeout_handle.cancel()
        del self._pending_operations[mid]
|
|
|
|
async def _discovery_cooldown(self) -> None:
    """Wait until discovery and subscribing have been quiet long enough.

    Returns once DISCOVERY_COOLDOWN seconds have passed since the last
    discovery and the last subscribe. While subscriptions are pending,
    the current time counts as the last subscribe.
    """
    now = time.monotonic()
    # Reset discovery and subscribe cooldowns
    self._mqtt_data.last_discovery = now
    self._last_subscribe = now

    while True:
        last_discovery = self._mqtt_data.last_discovery
        if self._pending_subscriptions:
            last_subscribe = now
        else:
            last_subscribe = self._last_subscribe
        wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
        if now >= wait_until:
            return
        await asyncio.sleep(wait_until - now)
        now = time.monotonic()
|
|
|
|
|
|
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:
    """Return a predicate telling whether a topic matches ``subscription``.

    The predicate is backed by an MQTTMatcher holding only this
    subscription and returns False when nothing matches.
    """
    topic_matcher = MQTTMatcher()  # type: ignore[no-untyped-call]
    topic_matcher[subscription] = True

    def _matches(topic: str) -> bool:
        return next(topic_matcher.iter_match(topic), False)  # type: ignore[no-untyped-call]

    return _matches
|