mirror of
https://github.com/home-assistant/core.git
synced 2026-05-08 17:49:37 +01:00
Verify local_only webhook on MockRequest (#169271)
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -374,6 +374,7 @@ class CloudClient(Interface):
|
||||
method=payload["method"],
|
||||
query_string=payload["query"],
|
||||
mock_source=DOMAIN,
|
||||
remote=None, # Remote will be used for the local_only check, but since this is from the cloud we want it to be None to mark it as non-local and bypass the ip parsing and remote checks
|
||||
)
|
||||
|
||||
response = await webhook.async_handle_webhook(
|
||||
|
||||
@@ -120,8 +120,11 @@ async def async_handle_webhook(
|
||||
handlers: dict[str, dict[str, Any]] = hass.data.setdefault(DOMAIN, {})
|
||||
|
||||
content_stream: StreamReader | MockStreamReader
|
||||
received_from: str | None
|
||||
if isinstance(request, MockRequest):
|
||||
received_from = request.mock_source
|
||||
if request.remote is not None:
|
||||
received_from += f" ({request.remote})"
|
||||
content_stream = request.content
|
||||
method_name = request.method
|
||||
else:
|
||||
@@ -156,11 +159,11 @@ async def async_handle_webhook(
|
||||
)
|
||||
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||
|
||||
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
|
||||
is_local = not is_cloud_connection(hass)
|
||||
if webhook["local_only"] in (True, None):
|
||||
is_local = not (is_cloud_connection(hass) or request.remote is None)
|
||||
|
||||
if is_local:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(request, Request)
|
||||
assert request.remote is not None
|
||||
|
||||
try:
|
||||
@@ -273,6 +276,7 @@ async def websocket_handle(
|
||||
method=msg["method"],
|
||||
query_string=msg["query"],
|
||||
mock_source=f"{DOMAIN}/ws",
|
||||
remote=connection.remote,
|
||||
)
|
||||
|
||||
response = await async_handle_webhook(hass, msg["webhook_id"], request)
|
||||
|
||||
@@ -78,6 +78,7 @@ class AuthPhase:
|
||||
self._send_message,
|
||||
self._request[KEY_HASS_USER],
|
||||
refresh_token=None,
|
||||
remote=self._request.remote,
|
||||
)
|
||||
await self._send_bytes_text(AUTH_OK_MESSAGE)
|
||||
self._logger.debug("Auth OK (unix socket)")
|
||||
@@ -111,6 +112,7 @@ class AuthPhase:
|
||||
self._send_message,
|
||||
refresh_token.user,
|
||||
refresh_token,
|
||||
remote=self._request.remote,
|
||||
)
|
||||
conn.subscriptions["auth"] = (
|
||||
self._hass.auth.async_register_revoke_token_callback(
|
||||
|
||||
@@ -47,6 +47,7 @@ class ActiveConnection:
|
||||
"last_id",
|
||||
"logger",
|
||||
"refresh_token_id",
|
||||
"remote",
|
||||
"send_message",
|
||||
"subscriptions",
|
||||
"supported_features",
|
||||
@@ -60,6 +61,7 @@ class ActiveConnection:
|
||||
send_message: Callable[[bytes | str | dict[str, Any]], None],
|
||||
user: User,
|
||||
refresh_token: RefreshToken | None,
|
||||
remote: str | None,
|
||||
) -> None:
|
||||
"""Initialize an active connection."""
|
||||
self.logger = logger
|
||||
@@ -67,6 +69,7 @@ class ActiveConnection:
|
||||
self.send_message = send_message
|
||||
self.user = user
|
||||
self.refresh_token_id = refresh_token.id if refresh_token else None
|
||||
self.remote = remote
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
self.can_coalesce = False
|
||||
|
||||
@@ -58,7 +58,7 @@ _MOCK_PAYLOAD_WRITER = MockPayloadWriter()
|
||||
class MockRequest:
|
||||
"""Mock an aiohttp request."""
|
||||
|
||||
mock_source: str | None = None
|
||||
mock_source: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -69,6 +69,7 @@ class MockRequest:
|
||||
headers: dict[str, str] | None = None,
|
||||
query_string: str | None = None,
|
||||
url: str = "",
|
||||
remote: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize a request."""
|
||||
self.method = method
|
||||
@@ -81,6 +82,7 @@ class MockRequest:
|
||||
self._content = content
|
||||
self.mock_source = mock_source
|
||||
self._payload_writer = _MOCK_PAYLOAD_WRITER
|
||||
self.remote = remote
|
||||
|
||||
async def _prepare_hook(self, response: Any) -> None:
|
||||
"""Prepare hook."""
|
||||
|
||||
@@ -316,6 +316,51 @@ async def test_webhook_msg(
|
||||
assert '{"nonexisting": "payload"}' in caplog.text
|
||||
|
||||
|
||||
async def test_webhook_msg_local_only(hass: HomeAssistant) -> None:
|
||||
"""Test a cloudhook for a local_only webhook does not fire the handler."""
|
||||
with patch("hass_nabucasa.Cloud.initialize"):
|
||||
setup = await async_setup_component(hass, "cloud", {"cloud": {}})
|
||||
assert setup
|
||||
cloud = hass.data[DATA_CLOUD]
|
||||
|
||||
await cloud.client.prefs.async_initialize()
|
||||
await cloud.client.prefs.async_update(
|
||||
cloudhooks={
|
||||
"mock-webhook-id": {
|
||||
"webhook_id": "mock-webhook-id",
|
||||
"cloudhook_id": "mock-cloud-id",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
received = []
|
||||
|
||||
async def handler(
|
||||
hass: HomeAssistant, webhook_id: str, request: web.Request
|
||||
) -> web.Response:
|
||||
"""Handle a webhook."""
|
||||
received.append(request)
|
||||
return web.json_response({"from": "handler"})
|
||||
|
||||
webhook.async_register(
|
||||
hass, "test", "Test", "mock-webhook-id", handler, local_only=True
|
||||
)
|
||||
|
||||
response = await cloud.client.async_webhook_message(
|
||||
{
|
||||
"cloudhook_id": "mock-cloud-id",
|
||||
"body": '{"hello": "world"}',
|
||||
"headers": {"content-type": CONTENT_TYPE_JSON},
|
||||
"method": "POST",
|
||||
"query": None,
|
||||
}
|
||||
)
|
||||
|
||||
assert response["status"] == 200
|
||||
# Handler not called because cloudhooks are not considered local
|
||||
assert len(received) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_cloud_setup", "mock_cloud_login")
|
||||
async def test_google_config_expose_entity(
|
||||
hass: HomeAssistant,
|
||||
|
||||
@@ -9,10 +9,13 @@ from aiohttp.test_utils import TestClient
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import webhook
|
||||
from homeassistant.components.websocket_api import auth, http
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core_config import async_process_ha_core_config
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.aiohttp import MockRequest
|
||||
|
||||
from tests.test_util import mock_real_ip
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
|
||||
@@ -267,6 +270,45 @@ async def test_webhook_local_only(hass: HomeAssistant, mock_client) -> None:
|
||||
assert len(hooks) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("remote", "expected_calls"),
|
||||
[
|
||||
(None, 0),
|
||||
("123.123.123.123", 0),
|
||||
("not-an-ip", 0),
|
||||
("192.168.1.50", 1),
|
||||
],
|
||||
)
|
||||
async def test_webhook_local_only_mock_request(
|
||||
hass: HomeAssistant, remote: str | None, expected_calls: int
|
||||
) -> None:
|
||||
"""Test local_only webhooks for MockRequests with various remote values."""
|
||||
await async_setup_component(hass, "webhook", {})
|
||||
|
||||
hooks = []
|
||||
webhook_id = webhook.async_generate_id()
|
||||
|
||||
async def handle(hass: HomeAssistant, webhook_id: str, request: web.Request):
|
||||
"""Handle webhook."""
|
||||
hooks.append((hass, webhook_id, await request.text()))
|
||||
|
||||
webhook.async_register(
|
||||
hass, "test", "Test hook", webhook_id, handle, local_only=True
|
||||
)
|
||||
|
||||
request = MockRequest(
|
||||
content=b'{"data": true}',
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
query_string="",
|
||||
mock_source="test",
|
||||
remote=remote,
|
||||
)
|
||||
resp = await webhook.async_handle_webhook(hass, webhook_id, request)
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert len(hooks) == expected_calls
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("enable_custom_integrations")
|
||||
async def test_listing_webhook(
|
||||
hass: HomeAssistant,
|
||||
@@ -356,6 +398,8 @@ async def test_ws_webhook(
|
||||
assert received[0].headers["content-type"] == "application/json"
|
||||
assert received[0].query == {"a": "2"}
|
||||
assert await received[0].json() == {"hello": "world"}
|
||||
# The MockRequest is created with the websocket connection's remote IP
|
||||
assert received[0].remote is not None
|
||||
|
||||
# Non existing webhook
|
||||
caplog.clear()
|
||||
@@ -383,3 +427,68 @@ async def test_ws_webhook(
|
||||
in caplog.text
|
||||
)
|
||||
assert '{"nonexisting": "payload"}' in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("remote_ip", "expected_calls"),
|
||||
[
|
||||
("192.168.1.50", 1),
|
||||
("123.123.123.123", 0),
|
||||
],
|
||||
)
|
||||
async def test_ws_webhook_local_only(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
hass_access_token: str,
|
||||
remote_ip: str,
|
||||
expected_calls: int,
|
||||
) -> None:
|
||||
"""Test a local_only webhook over the websocket connection."""
|
||||
assert await async_setup_component(hass, "webhook", {})
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
received = []
|
||||
|
||||
async def handler(
|
||||
hass: HomeAssistant, webhook_id: str, request: web.Request
|
||||
) -> web.Response:
|
||||
"""Handle a webhook."""
|
||||
received.append(request)
|
||||
return web.json_response({"from": "handler"})
|
||||
|
||||
webhook.async_register(
|
||||
hass, "test", "Test", "mock-webhook-id", handler, local_only=True
|
||||
)
|
||||
|
||||
set_mock_ip = mock_real_ip(hass.http.app)
|
||||
set_mock_ip(remote_ip)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
async with client.ws_connect(http.URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == auth.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({"type": auth.TYPE_AUTH, "access_token": hass_access_token})
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == auth.TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "webhook/handle",
|
||||
"webhook_id": "mock-webhook-id",
|
||||
"method": "POST",
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
"body": '{"hello": "world"}',
|
||||
"query": "",
|
||||
}
|
||||
)
|
||||
result = await ws.receive_json()
|
||||
|
||||
assert result["success"], result
|
||||
assert result["result"]["status"] == HTTPStatus.OK
|
||||
assert len(received) == expected_calls
|
||||
if expected_calls:
|
||||
assert received[0].remote == remote_ip
|
||||
|
||||
@@ -100,7 +100,12 @@ async def test_exception_handling(
|
||||
) as current_request:
|
||||
current_request.get.return_value = mocked_request
|
||||
conn = websocket_api.ActiveConnection(
|
||||
logging.getLogger(__name__), hass, send_messages.append, user, refresh_token
|
||||
logging.getLogger(__name__),
|
||||
hass,
|
||||
send_messages.append,
|
||||
user,
|
||||
refresh_token,
|
||||
remote="127.0.0.42",
|
||||
)
|
||||
|
||||
conn.async_handle_exception({"id": 5}, exc)
|
||||
@@ -113,7 +118,7 @@ async def test_exception_handling(
|
||||
async def test_binary_handler_registration() -> None:
|
||||
"""Test binary handler registration."""
|
||||
connection = websocket_api.ActiveConnection(
|
||||
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock()
|
||||
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock(), remote=None
|
||||
)
|
||||
|
||||
# One filler to align indexes with prefix numbers
|
||||
|
||||
Reference in New Issue
Block a user