1
0
mirror of https://github.com/home-assistant/core.git synced 2026-05-08 17:49:37 +01:00

Verify local_only webhook on MockRequest (#169271)

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
Robert Resch
2026-04-27 14:49:48 +02:00
committed by GitHub
parent 642206699d
commit 8a22e84db0
8 changed files with 177 additions and 6 deletions
+1
View File
@@ -374,6 +374,7 @@ class CloudClient(Interface):
method=payload["method"],
query_string=payload["query"],
mock_source=DOMAIN,
remote=None, # Remote will be used for the local_only check, but since this is from the cloud we want it to be None to mark it as non-local and bypass the ip parsing and remote checks
)
response = await webhook.async_handle_webhook(
+7 -3
View File
@@ -120,8 +120,11 @@ async def async_handle_webhook(
handlers: dict[str, dict[str, Any]] = hass.data.setdefault(DOMAIN, {})
content_stream: StreamReader | MockStreamReader
received_from: str | None
if isinstance(request, MockRequest):
received_from = request.mock_source
if request.remote is not None:
received_from += f" ({request.remote})"
content_stream = request.content
method_name = request.method
else:
@@ -156,11 +159,11 @@ async def async_handle_webhook(
)
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
is_local = not is_cloud_connection(hass)
if webhook["local_only"] in (True, None):
is_local = not (is_cloud_connection(hass) or request.remote is None)
if is_local:
if TYPE_CHECKING:
assert isinstance(request, Request)
assert request.remote is not None
try:
@@ -273,6 +276,7 @@ async def websocket_handle(
method=msg["method"],
query_string=msg["query"],
mock_source=f"{DOMAIN}/ws",
remote=connection.remote,
)
response = await async_handle_webhook(hass, msg["webhook_id"], request)
@@ -78,6 +78,7 @@ class AuthPhase:
self._send_message,
self._request[KEY_HASS_USER],
refresh_token=None,
remote=self._request.remote,
)
await self._send_bytes_text(AUTH_OK_MESSAGE)
self._logger.debug("Auth OK (unix socket)")
@@ -111,6 +112,7 @@ class AuthPhase:
self._send_message,
refresh_token.user,
refresh_token,
remote=self._request.remote,
)
conn.subscriptions["auth"] = (
self._hass.auth.async_register_revoke_token_callback(
@@ -47,6 +47,7 @@ class ActiveConnection:
"last_id",
"logger",
"refresh_token_id",
"remote",
"send_message",
"subscriptions",
"supported_features",
@@ -60,6 +61,7 @@ class ActiveConnection:
send_message: Callable[[bytes | str | dict[str, Any]], None],
user: User,
refresh_token: RefreshToken | None,
remote: str | None,
) -> None:
"""Initialize an active connection."""
self.logger = logger
@@ -67,6 +69,7 @@ class ActiveConnection:
self.send_message = send_message
self.user = user
self.refresh_token_id = refresh_token.id if refresh_token else None
self.remote = remote
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0
self.can_coalesce = False
+3 -1
View File
@@ -58,7 +58,7 @@ _MOCK_PAYLOAD_WRITER = MockPayloadWriter()
class MockRequest:
"""Mock an aiohttp request."""
mock_source: str | None = None
mock_source: str
def __init__(
self,
@@ -69,6 +69,7 @@ class MockRequest:
headers: dict[str, str] | None = None,
query_string: str | None = None,
url: str = "",
remote: str | None = None,
) -> None:
"""Initialize a request."""
self.method = method
@@ -81,6 +82,7 @@ class MockRequest:
self._content = content
self.mock_source = mock_source
self._payload_writer = _MOCK_PAYLOAD_WRITER
self.remote = remote
async def _prepare_hook(self, response: Any) -> None:
"""Prepare hook."""
+45
View File
@@ -316,6 +316,51 @@ async def test_webhook_msg(
assert '{"nonexisting": "payload"}' in caplog.text
async def test_webhook_msg_local_only(hass: HomeAssistant) -> None:
"""Test a cloudhook for a local_only webhook does not fire the handler."""
with patch("hass_nabucasa.Cloud.initialize"):
setup = await async_setup_component(hass, "cloud", {"cloud": {}})
assert setup
cloud = hass.data[DATA_CLOUD]
await cloud.client.prefs.async_initialize()
await cloud.client.prefs.async_update(
cloudhooks={
"mock-webhook-id": {
"webhook_id": "mock-webhook-id",
"cloudhook_id": "mock-cloud-id",
},
}
)
received = []
async def handler(
hass: HomeAssistant, webhook_id: str, request: web.Request
) -> web.Response:
"""Handle a webhook."""
received.append(request)
return web.json_response({"from": "handler"})
webhook.async_register(
hass, "test", "Test", "mock-webhook-id", handler, local_only=True
)
response = await cloud.client.async_webhook_message(
{
"cloudhook_id": "mock-cloud-id",
"body": '{"hello": "world"}',
"headers": {"content-type": CONTENT_TYPE_JSON},
"method": "POST",
"query": None,
}
)
assert response["status"] == 200
# Handler not called because cloudhooks are not considered local
assert len(received) == 0
@pytest.mark.usefixtures("mock_cloud_setup", "mock_cloud_login")
async def test_google_config_expose_entity(
hass: HomeAssistant,
+109
View File
@@ -9,10 +9,13 @@ from aiohttp.test_utils import TestClient
import pytest
from homeassistant.components import webhook
from homeassistant.components.websocket_api import auth, http
from homeassistant.core import HomeAssistant
from homeassistant.core_config import async_process_ha_core_config
from homeassistant.setup import async_setup_component
from homeassistant.util.aiohttp import MockRequest
from tests.test_util import mock_real_ip
from tests.typing import ClientSessionGenerator, WebSocketGenerator
@@ -267,6 +270,45 @@ async def test_webhook_local_only(hass: HomeAssistant, mock_client) -> None:
assert len(hooks) == 1
@pytest.mark.parametrize(
("remote", "expected_calls"),
[
(None, 0),
("123.123.123.123", 0),
("not-an-ip", 0),
("192.168.1.50", 1),
],
)
async def test_webhook_local_only_mock_request(
hass: HomeAssistant, remote: str | None, expected_calls: int
) -> None:
"""Test local_only webhooks for MockRequests with various remote values."""
await async_setup_component(hass, "webhook", {})
hooks = []
webhook_id = webhook.async_generate_id()
async def handle(hass: HomeAssistant, webhook_id: str, request: web.Request):
"""Handle webhook."""
hooks.append((hass, webhook_id, await request.text()))
webhook.async_register(
hass, "test", "Test hook", webhook_id, handle, local_only=True
)
request = MockRequest(
content=b'{"data": true}',
headers={"Content-Type": "application/json"},
method="POST",
query_string="",
mock_source="test",
remote=remote,
)
resp = await webhook.async_handle_webhook(hass, webhook_id, request)
assert resp.status == HTTPStatus.OK
assert len(hooks) == expected_calls
@pytest.mark.usefixtures("enable_custom_integrations")
async def test_listing_webhook(
hass: HomeAssistant,
@@ -356,6 +398,8 @@ async def test_ws_webhook(
assert received[0].headers["content-type"] == "application/json"
assert received[0].query == {"a": "2"}
assert await received[0].json() == {"hello": "world"}
# The MockRequest is created with the websocket connection's remote IP
assert received[0].remote is not None
# Non existing webhook
caplog.clear()
@@ -383,3 +427,68 @@ async def test_ws_webhook(
in caplog.text
)
assert '{"nonexisting": "payload"}' in caplog.text
@pytest.mark.parametrize(
("remote_ip", "expected_calls"),
[
("192.168.1.50", 1),
("123.123.123.123", 0),
],
)
async def test_ws_webhook_local_only(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
hass_access_token: str,
remote_ip: str,
expected_calls: int,
) -> None:
"""Test a local_only webhook over the websocket connection."""
assert await async_setup_component(hass, "webhook", {})
assert await async_setup_component(hass, "websocket_api", {})
await hass.async_block_till_done()
received = []
async def handler(
hass: HomeAssistant, webhook_id: str, request: web.Request
) -> web.Response:
"""Handle a webhook."""
received.append(request)
return web.json_response({"from": "handler"})
webhook.async_register(
hass, "test", "Test", "mock-webhook-id", handler, local_only=True
)
set_mock_ip = mock_real_ip(hass.http.app)
set_mock_ip(remote_ip)
client = await hass_client_no_auth()
async with client.ws_connect(http.URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg["type"] == auth.TYPE_AUTH_REQUIRED
await ws.send_json({"type": auth.TYPE_AUTH, "access_token": hass_access_token})
auth_msg = await ws.receive_json()
assert auth_msg["type"] == auth.TYPE_AUTH_OK
await ws.send_json(
{
"id": 5,
"type": "webhook/handle",
"webhook_id": "mock-webhook-id",
"method": "POST",
"headers": {"Content-Type": "application/json"},
"body": '{"hello": "world"}',
"query": "",
}
)
result = await ws.receive_json()
assert result["success"], result
assert result["result"]["status"] == HTTPStatus.OK
assert len(received) == expected_calls
if expected_calls:
assert received[0].remote == remote_ip
@@ -100,7 +100,12 @@ async def test_exception_handling(
) as current_request:
current_request.get.return_value = mocked_request
conn = websocket_api.ActiveConnection(
logging.getLogger(__name__), hass, send_messages.append, user, refresh_token
logging.getLogger(__name__),
hass,
send_messages.append,
user,
refresh_token,
remote="127.0.0.42",
)
conn.async_handle_exception({"id": 5}, exc)
@@ -113,7 +118,7 @@ async def test_exception_handling(
async def test_binary_handler_registration() -> None:
"""Test binary handler registration."""
connection = websocket_api.ActiveConnection(
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock()
None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock(), remote=None
)
# One filler to align indexes with prefix numbers