1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-21 02:18:47 +00:00
Files

722 lines
24 KiB
Python

"""HTML5 Push Messaging notification service."""
from __future__ import annotations
from contextlib import suppress
from datetime import datetime, timedelta
from http import HTTPStatus
import json
import logging
import time
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict, cast
from urllib.parse import urlparse
import uuid
from aiohttp import ClientError, ClientResponse, ClientSession, web
from aiohttp.hdrs import AUTHORIZATION
import jwt
from py_vapid import Vapid
from pywebpush import WebPusher, WebPushException, webpush_async
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import websocket_api
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.components.notify import (
ATTR_DATA,
ATTR_TARGET,
ATTR_TITLE,
ATTR_TITLE_DEFAULT,
BaseNotificationService,
NotifyEntity,
NotifyEntityFeature,
)
from homeassistant.components.websocket_api import ActiveConnection
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_NAME, URL_ROOT
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.helpers.json import save_json
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util import ensure_unique_string
from homeassistant.util.json import load_json_object
from .const import (
ATTR_VAPID_EMAIL,
ATTR_VAPID_PRV_KEY,
ATTR_VAPID_PUB_KEY,
DOMAIN,
SERVICE_DISMISS,
)
_LOGGER = logging.getLogger(__name__)
REGISTRATIONS_FILE = "html5_push_registrations.conf"
ATTR_SUBSCRIPTION = "subscription"
ATTR_BROWSER = "browser"
ATTR_ENDPOINT = "endpoint"
ATTR_KEYS = "keys"
ATTR_AUTH = "auth"
ATTR_P256DH = "p256dh"
ATTR_EXPIRATIONTIME = "expirationTime"
ATTR_TAG = "tag"
ATTR_ACTION = "action"
ATTR_ACTIONS = "actions"
ATTR_TYPE = "type"
ATTR_URL = "url"
ATTR_DISMISS = "dismiss"
ATTR_PRIORITY = "priority"
DEFAULT_PRIORITY = "normal"
ATTR_TTL = "ttl"
DEFAULT_TTL = 86400
DEFAULT_BADGE = "/static/images/notification-badge.png"
DEFAULT_ICON = "/static/icons/favicon-192x192.png"
ATTR_JWT = "jwt"
WS_TYPE_APPKEY = "notify/html5/appkey"
SCHEMA_WS_APPKEY = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_APPKEY}
)
# The number of days after the moment a notification is sent that a JWT
# is valid.
JWT_VALID_DAYS = 7
VAPID_CLAIM_VALID_HOURS = 12
KEYS_SCHEMA = vol.All(
dict,
vol.Schema(
{vol.Required(ATTR_AUTH): cv.string, vol.Required(ATTR_P256DH): cv.string}
),
)
SUBSCRIPTION_SCHEMA = vol.All(
dict,
vol.Schema(
{
vol.Required(ATTR_ENDPOINT): vol.Url(),
vol.Required(ATTR_KEYS): KEYS_SCHEMA,
vol.Optional(ATTR_EXPIRATIONTIME): vol.Any(None, cv.positive_int),
}
),
)
DISMISS_SERVICE_SCHEMA = vol.Schema(
{
vol.Optional(ATTR_TARGET): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(ATTR_DATA): dict,
}
)
REGISTER_SCHEMA = vol.Schema(
{
vol.Required(ATTR_SUBSCRIPTION): SUBSCRIPTION_SCHEMA,
vol.Required(ATTR_BROWSER): vol.In(["chrome", "firefox"]),
vol.Optional(ATTR_NAME): cv.string,
}
)
CALLBACK_EVENT_PAYLOAD_SCHEMA = vol.Schema(
{
vol.Required(ATTR_TAG): cv.string,
vol.Required(ATTR_TYPE): vol.In(["received", "clicked", "closed"]),
vol.Required(ATTR_TARGET): cv.string,
vol.Optional(ATTR_ACTION): cv.string,
vol.Optional(ATTR_DATA): dict,
}
)
NOTIFY_CALLBACK_EVENT = "html5_notification"
# Badge and timestamp are Chrome specific (not in official spec)
HTML5_SHOWNOTIFICATION_PARAMETERS = (
"actions",
"badge",
"body",
"dir",
"icon",
"image",
"lang",
"renotify",
"requireInteraction",
"tag",
"timestamp",
"vibrate",
"silent",
)
class Keys(TypedDict):
"""Types for keys."""
p256dh: str
auth: str
class Subscription(TypedDict):
"""Types for subscription."""
endpoint: str
expirationTime: int | None
keys: Keys
class Registration(TypedDict):
"""Types for registration."""
subscription: Subscription
browser: str
name: NotRequired[str]
async def async_get_service(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> HTML5NotificationService | None:
"""Get the HTML5 push notification service."""
if config:
return None
if discovery_info is None:
return None
json_path = hass.config.path(REGISTRATIONS_FILE)
registrations = await hass.async_add_executor_job(_load_config, json_path)
vapid_pub_key: str = discovery_info[ATTR_VAPID_PUB_KEY]
vapid_prv_key: str = discovery_info[ATTR_VAPID_PRV_KEY]
vapid_email: str = discovery_info[ATTR_VAPID_EMAIL]
@callback
def websocket_appkey(
_hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
connection.send_message(websocket_api.result_message(msg["id"], vapid_pub_key))
websocket_api.async_register_command(
hass, WS_TYPE_APPKEY, websocket_appkey, SCHEMA_WS_APPKEY
)
hass.http.register_view(HTML5PushRegistrationView(registrations, json_path))
hass.http.register_view(HTML5PushCallbackView(registrations))
session = async_get_clientsession(hass)
return HTML5NotificationService(
hass, session, vapid_prv_key, vapid_email, registrations, json_path
)
def _load_config(filename: str) -> dict[str, Registration]:
"""Load configuration."""
with suppress(HomeAssistantError):
return cast(dict[str, Registration], load_json_object(filename))
return {}
class HTML5PushRegistrationView(HomeAssistantView):
"""Accepts push registrations from a browser."""
url = "/api/notify.html5"
name = "api:notify.html5"
def __init__(self, registrations: dict[str, Registration], json_path: str) -> None:
"""Init HTML5PushRegistrationView."""
self.registrations = registrations
self.json_path = json_path
async def post(self, request: web.Request) -> web.Response:
"""Accept the POST request for push registrations from a browser."""
try:
data: Registration = await request.json()
except ValueError:
return self.json_message("Invalid JSON", HTTPStatus.BAD_REQUEST)
try:
data = cast(Registration, REGISTER_SCHEMA(data))
except vol.Invalid as ex:
return self.json_message(humanize_error(data, ex), HTTPStatus.BAD_REQUEST)
devname = data.get(ATTR_NAME)
data.pop(ATTR_NAME, None)
name = self.find_registration_name(data, devname)
previous_registration = self.registrations.get(name)
self.registrations[name] = data
try:
hass = request.app[KEY_HASS]
await hass.async_add_executor_job(
save_json, self.json_path, self.registrations
)
return self.json_message("Push notification subscriber registered.")
except HomeAssistantError:
if previous_registration is not None:
self.registrations[name] = previous_registration
else:
self.registrations.pop(name)
return self.json_message(
"Error saving registration.", HTTPStatus.INTERNAL_SERVER_ERROR
)
def find_registration_name(
self,
data: Registration,
suggested: str | None = None,
):
"""Find a registration name matching data or generate a unique one."""
endpoint = data["subscription"]["endpoint"]
for key, registration in self.registrations.items():
subscription = registration["subscription"]
if subscription.get(ATTR_ENDPOINT) == endpoint:
return key
return ensure_unique_string(suggested or "unnamed device", self.registrations)
async def delete(self, request: web.Request):
"""Delete a registration."""
try:
data: dict[str, Any] = await request.json()
except ValueError:
return self.json_message("Invalid JSON", HTTPStatus.BAD_REQUEST)
subscription: dict[str, Any] = data[ATTR_SUBSCRIPTION]
found = None
for key, registration in self.registrations.items():
if registration["subscription"] == subscription:
found = key
break
if not found:
# If not found, unregistering was already done. Return 200
return self.json_message("Registration not found.")
reg = self.registrations.pop(found)
try:
hass = request.app[KEY_HASS]
await hass.async_add_executor_job(
save_json, self.json_path, self.registrations
)
except HomeAssistantError:
self.registrations[found] = reg
return self.json_message(
"Error saving registration.", HTTPStatus.INTERNAL_SERVER_ERROR
)
return self.json_message("Push notification subscriber unregistered.")
class HTML5PushCallbackView(HomeAssistantView):
"""Accepts push registrations from a browser."""
requires_auth = False
url = "/api/notify.html5/callback"
name = "api:notify.html5/callback"
def __init__(self, registrations: dict[str, Registration]) -> None:
"""Init HTML5PushCallbackView."""
self.registrations = registrations
def decode_jwt(self, token: str) -> web.Response | dict[str, Any]:
"""Find the registration that signed this JWT and return it."""
# 1. Check claims w/o verifying to see if a target is in there.
# 2. If target in claims, attempt to verify against the given name.
# 2a. If decode is successful, return the payload.
# 2b. If decode is unsuccessful, return a 401.
target_check: dict[str, Any] = jwt.decode(
token, algorithms=["ES256", "HS256"], options={"verify_signature": False}
)
if target_check.get(ATTR_TARGET) in self.registrations:
possible_target = self.registrations[target_check[ATTR_TARGET]]
key = possible_target["subscription"]["keys"]["auth"]
with suppress(jwt.exceptions.DecodeError):
return jwt.decode(token, key, algorithms=["ES256", "HS256"])
return self.json_message(
"No target found in JWT", status_code=HTTPStatus.UNAUTHORIZED
)
# The following is based on code from Auth0
# https://auth0.com/docs/quickstart/backend/python
def check_authorization_header(
self, request: web.Request
) -> web.Response | dict[str, Any]:
"""Check the authorization header."""
if not (auth := request.headers.get(AUTHORIZATION)):
return self.json_message(
"Authorization header is expected", status_code=HTTPStatus.UNAUTHORIZED
)
parts = auth.split()
if parts[0].lower() != "bearer":
return self.json_message(
"Authorization header must start with Bearer",
status_code=HTTPStatus.UNAUTHORIZED,
)
if len(parts) != 2:
return self.json_message(
"Authorization header must be Bearer token",
status_code=HTTPStatus.UNAUTHORIZED,
)
token = parts[1]
try:
payload = self.decode_jwt(token)
except jwt.exceptions.InvalidTokenError:
return self.json_message(
"token is invalid", status_code=HTTPStatus.UNAUTHORIZED
)
return payload
async def post(self, request: web.Request) -> web.Response:
"""Accept the POST request for push registrations event callback."""
auth_check = self.check_authorization_header(request)
if not isinstance(auth_check, dict):
return auth_check
try:
data: dict[str, str] = await request.json()
except ValueError:
return self.json_message("Invalid JSON", HTTPStatus.BAD_REQUEST)
event_payload: dict[str, Any] = {
ATTR_TAG: data.get(ATTR_TAG),
ATTR_TYPE: data[ATTR_TYPE],
ATTR_TARGET: auth_check[ATTR_TARGET],
}
if data.get(ATTR_ACTION) is not None:
event_payload[ATTR_ACTION] = data.get(ATTR_ACTION)
if data.get(ATTR_DATA) is not None:
event_payload[ATTR_DATA] = data.get(ATTR_DATA)
try:
event_payload = CALLBACK_EVENT_PAYLOAD_SCHEMA(event_payload)
except vol.Invalid as ex:
_LOGGER.warning(
"Callback event payload is not valid: %s",
humanize_error(event_payload, ex),
)
event_name = f"{NOTIFY_CALLBACK_EVENT}.{event_payload[ATTR_TYPE]}"
request.app[KEY_HASS].bus.fire(event_name, event_payload)
return self.json({"status": "ok", "event": event_payload[ATTR_TYPE]})
class HTML5NotificationService(BaseNotificationService):
"""Implement the notification service for HTML5."""
def __init__(
self,
hass: HomeAssistant,
session: ClientSession,
vapid_prv: str,
vapid_email: str,
registrations: dict[str, Registration],
json_path: str,
) -> None:
"""Initialize the service."""
self.session = session
self._vapid_prv = vapid_prv
self._vapid_email = vapid_email
self.registrations = registrations
self.registrations_json_path = json_path
async def async_dismiss_message(service: ServiceCall) -> None:
"""Handle dismissing notification message service calls."""
kwargs: dict[str, Any] = {}
if self.targets is not None:
kwargs[ATTR_TARGET] = self.targets
elif service.data.get(ATTR_TARGET) is not None:
kwargs[ATTR_TARGET] = service.data.get(ATTR_TARGET)
kwargs[ATTR_DATA] = service.data.get(ATTR_DATA)
await self.async_dismiss(**kwargs)
hass.services.async_register(
DOMAIN,
SERVICE_DISMISS,
async_dismiss_message,
schema=DISMISS_SERVICE_SCHEMA,
)
@property
def targets(self) -> dict[str, str]:
"""Return a dictionary of registered targets."""
return {registration: registration for registration in self.registrations}
async def async_dismiss(self, **kwargs: Any) -> None:
"""Dismisses a notification.
This method must be run in the event loop.
"""
data: dict[str, Any] | None = kwargs.get(ATTR_DATA)
tag: str = data.get(ATTR_TAG, "") if data else ""
payload = {ATTR_TAG: tag, ATTR_DISMISS: True, ATTR_DATA: {}}
await self._push_message(payload, **kwargs)
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
"""Send a message to a user."""
tag = str(uuid.uuid4())
payload: dict[str, Any] = {
"badge": DEFAULT_BADGE,
"body": message,
ATTR_DATA: {},
"icon": DEFAULT_ICON,
ATTR_TAG: tag,
ATTR_TITLE: kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT),
}
data: dict[str, Any] | None = kwargs.get(ATTR_DATA)
if data:
# Pick out fields that should go into the notification directly vs
# into the notification data dictionary.
data_tmp: dict[str, Any] = {}
for key, val in data.items():
if key in HTML5_SHOWNOTIFICATION_PARAMETERS:
payload[key] = val
else:
data_tmp[key] = val
payload[ATTR_DATA] = data_tmp
if (
payload[ATTR_DATA].get(ATTR_URL) is None
and payload.get(ATTR_ACTIONS) is None
):
payload[ATTR_DATA][ATTR_URL] = URL_ROOT
await self._push_message(payload, **kwargs)
async def _push_message(self, payload: dict[str, Any], **kwargs: Any) -> None:
"""Send the message."""
timestamp = int(time.time())
ttl = int(kwargs.get(ATTR_TTL, DEFAULT_TTL))
priority: str = kwargs.get(ATTR_PRIORITY, DEFAULT_PRIORITY)
if priority not in ["normal", "high"]:
priority = DEFAULT_PRIORITY
payload["timestamp"] = timestamp * 1000 # Javascript ms since epoch
if not (targets := kwargs.get(ATTR_TARGET)):
targets = self.registrations.keys()
for target in list(targets):
info = self.registrations.get(target)
try:
info = cast(Registration, REGISTER_SCHEMA(info))
except vol.Invalid:
_LOGGER.error(
"%s is not a valid HTML5 push notification target", target
)
continue
subscription = info["subscription"]
payload[ATTR_DATA][ATTR_JWT] = add_jwt(
timestamp,
target,
payload[ATTR_TAG],
subscription["keys"]["auth"],
)
webpusher = WebPusher(
cast(dict[str, Any], info["subscription"]), aiohttp_session=self.session
)
endpoint = urlparse(subscription["endpoint"])
vapid_claims = {
"sub": f"mailto:{self._vapid_email}",
"aud": f"{endpoint.scheme}://{endpoint.netloc}",
"exp": timestamp + (VAPID_CLAIM_VALID_HOURS * 60 * 60),
}
vapid_headers = Vapid.from_string(self._vapid_prv).sign(vapid_claims)
vapid_headers.update({"urgency": priority, "priority": priority})
response = await webpusher.send_async(
data=json.dumps(payload), headers=vapid_headers, ttl=ttl
)
if TYPE_CHECKING:
assert not isinstance(response, str)
if response.status == HTTPStatus.GONE:
_LOGGER.info("Notification channel has expired")
reg = self.registrations.pop(target)
try:
await self.hass.async_add_executor_job(
save_json, self.registrations_json_path, self.registrations
)
except HomeAssistantError:
self.registrations[target] = reg
_LOGGER.error("Error saving registration")
else:
_LOGGER.info("Configuration saved")
elif response.status >= HTTPStatus.BAD_REQUEST:
_LOGGER.error(
"There was an issue sending the notification %s: %s",
response.status,
await response.text(),
)
def add_jwt(timestamp: int, target: str, tag: str, jwt_secret: str) -> str:
"""Create JWT json to put into payload."""
jwt_exp = datetime.fromtimestamp(timestamp) + timedelta(days=JWT_VALID_DAYS)
jwt_claims = {
"exp": jwt_exp,
"nbf": timestamp,
"iat": timestamp,
ATTR_TARGET: target,
ATTR_TAG: tag,
}
return jwt.encode(jwt_claims, jwt_secret)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up the notification entity platform."""
json_path = hass.config.path(REGISTRATIONS_FILE)
registrations = await hass.async_add_executor_job(_load_config, json_path)
session = async_get_clientsession(hass)
async_add_entities(
HTML5NotifyEntity(config_entry, target, registrations, session, json_path)
for target in registrations
)
class HTML5NotifyEntity(NotifyEntity):
"""Representation of a notification entity."""
_attr_has_entity_name = True
_attr_name = None
_attr_supported_features = NotifyEntityFeature.TITLE
def __init__(
self,
config_entry: ConfigEntry,
target: str,
registrations: dict[str, Registration],
session: ClientSession,
json_path: str,
) -> None:
"""Initialize the entity."""
self.config_entry = config_entry
self.target = target
self.registrations = registrations
self.registration = registrations[target]
self.session = session
self.json_path = json_path
self._attr_unique_id = f"{config_entry.entry_id}_{target}_device"
self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE,
name=target,
model=self.registration["browser"].capitalize(),
identifiers={(DOMAIN, f"{config_entry.entry_id}_{target}")},
)
async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send a message to a device."""
timestamp = int(time.time())
tag = str(uuid.uuid4())
payload: dict[str, Any] = {
"badge": DEFAULT_BADGE,
"body": message,
"icon": DEFAULT_ICON,
ATTR_TAG: tag,
ATTR_TITLE: title or ATTR_TITLE_DEFAULT,
"timestamp": timestamp * 1000,
ATTR_DATA: {
ATTR_JWT: add_jwt(
timestamp,
self.target,
tag,
self.registration["subscription"]["keys"]["auth"],
)
},
}
endpoint = urlparse(self.registration["subscription"]["endpoint"])
vapid_claims = {
"sub": f"mailto:{self.config_entry.data[ATTR_VAPID_EMAIL]}",
"aud": f"{endpoint.scheme}://{endpoint.netloc}",
"exp": timestamp + (VAPID_CLAIM_VALID_HOURS * 60 * 60),
}
try:
response = await webpush_async(
cast(dict[str, Any], self.registration["subscription"]),
json.dumps(payload),
self.config_entry.data[ATTR_VAPID_PRV_KEY],
vapid_claims,
aiohttp_session=self.session,
)
cast(ClientResponse, response).raise_for_status()
except WebPushException as e:
if cast(ClientResponse, e.response).status == HTTPStatus.GONE:
reg = self.registrations.pop(self.target)
try:
await self.hass.async_add_executor_job(
save_json, self.json_path, self.registrations
)
except HomeAssistantError:
self.registrations[self.target] = reg
_LOGGER.error("Error saving registration")
self.async_write_ha_state()
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="channel_expired",
translation_placeholders={"target": self.target},
) from e
_LOGGER.debug("Full exception", exc_info=True)
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="request_error",
translation_placeholders={"target": self.target},
) from e
except ClientError as e:
_LOGGER.debug("Full exception", exc_info=True)
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="connection_error",
translation_placeholders={"target": self.target},
) from e
@property
def available(self) -> bool:
"""Return True if entity is available."""
return super().available and self.target in self.registrations