mirror of
https://github.com/home-assistant/core.git
synced 2025-12-25 05:26:47 +00:00
Implement websocket message coalescing (#77238)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
@@ -2,15 +2,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from contextlib import asynccontextmanager
|
||||
import functools
|
||||
from json import JSONDecoder, loads
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
from aiohttp import client
|
||||
from aiohttp.pytest_plugin import AiohttpClient
|
||||
from aiohttp.test_utils import (
|
||||
BaseTestServer,
|
||||
TestClient,
|
||||
TestServer,
|
||||
make_mocked_request,
|
||||
)
|
||||
from aiohttp.web import Application
|
||||
import freezegun
|
||||
import multidict
|
||||
import pytest
|
||||
@@ -57,6 +67,7 @@ from tests.components.recorder.common import ( # noqa: E402, isort:skip
|
||||
async_recorder_block_till_done,
|
||||
)
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -203,6 +214,97 @@ def load_registries():
|
||||
return True
|
||||
|
||||
|
||||
class CoalescingResponse(client.ClientWebSocketResponse):
|
||||
"""ClientWebSocketResponse client that mimics the websocket js code."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Init the ClientWebSocketResponse."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._recv_buffer: list[Any] = []
|
||||
|
||||
async def receive_json(
|
||||
self,
|
||||
*,
|
||||
loads: JSONDecoder = loads,
|
||||
timeout: float | None = None,
|
||||
) -> Any:
|
||||
"""receive_json or from buffer."""
|
||||
if self._recv_buffer:
|
||||
return self._recv_buffer.pop(0)
|
||||
data = await self.receive_str(timeout=timeout)
|
||||
decoded = loads(data)
|
||||
if isinstance(decoded, list):
|
||||
self._recv_buffer = decoded
|
||||
return self._recv_buffer.pop(0)
|
||||
return decoded
|
||||
|
||||
|
||||
class CoalescingClient(TestClient):
|
||||
"""Client that mimics the websocket js code."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Init TestClient."""
|
||||
super().__init__(*args, ws_response_class=CoalescingResponse, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client_cls():
|
||||
"""Override the test class for aiohttp."""
|
||||
return CoalescingClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> Generator[AiohttpClient, None, None]:
|
||||
"""Override the default aiohttp_client since 3.x does not support aiohttp_client_cls.
|
||||
|
||||
Remove this when upgrading to 4.x as aiohttp_client_cls
|
||||
will do the same thing
|
||||
|
||||
aiohttp_client(app, **kwargs)
|
||||
aiohttp_client(server, **kwargs)
|
||||
aiohttp_client(raw_server, **kwargs)
|
||||
"""
|
||||
clients = []
|
||||
|
||||
async def go(
|
||||
__param: Application | BaseTestServer,
|
||||
*args: Any,
|
||||
server_kwargs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> TestClient:
|
||||
|
||||
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
|
||||
__param, (Application, BaseTestServer)
|
||||
):
|
||||
__param = __param(loop, *args, **kwargs)
|
||||
kwargs = {}
|
||||
else:
|
||||
assert not args, "args should be empty"
|
||||
|
||||
if isinstance(__param, Application):
|
||||
server_kwargs = server_kwargs or {}
|
||||
server = TestServer(__param, loop=loop, **server_kwargs)
|
||||
client = CoalescingClient(server, loop=loop, **kwargs)
|
||||
elif isinstance(__param, BaseTestServer):
|
||||
client = TestClient(__param, loop=loop, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unknown argument type: %r" % type(__param))
|
||||
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
yield go
|
||||
|
||||
async def finalize() -> None:
|
||||
while clients:
|
||||
await clients.pop().close()
|
||||
|
||||
loop.run_until_complete(finalize())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass(loop, load_registries, hass_storage, request):
|
||||
"""Fixture to provide a test instance of Home Assistant."""
|
||||
|
||||
Reference in New Issue
Block a user