diff --git a/homeassistant/components/cloud/client.py b/homeassistant/components/cloud/client.py index b1c3bebcaae..51bc0bd4e39 100644 --- a/homeassistant/components/cloud/client.py +++ b/homeassistant/components/cloud/client.py @@ -374,6 +374,7 @@ class CloudClient(Interface): method=payload["method"], query_string=payload["query"], mock_source=DOMAIN, + remote=None, # Remote will be used for the local_only check, but since this is from the cloud we want it to be None to mark it as non-local and bypass the ip parsing and remote checks ) response = await webhook.async_handle_webhook( diff --git a/homeassistant/components/webhook/__init__.py b/homeassistant/components/webhook/__init__.py index 5778658b128..68625a98448 100644 --- a/homeassistant/components/webhook/__init__.py +++ b/homeassistant/components/webhook/__init__.py @@ -120,8 +120,11 @@ async def async_handle_webhook( handlers: dict[str, dict[str, Any]] = hass.data.setdefault(DOMAIN, {}) content_stream: StreamReader | MockStreamReader + received_from: str | None if isinstance(request, MockRequest): received_from = request.mock_source + if request.remote is not None: + received_from += f" ({request.remote})" content_stream = request.content method_name = request.method else: @@ -156,11 +159,11 @@ async def async_handle_webhook( ) return Response(status=HTTPStatus.METHOD_NOT_ALLOWED) - if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest): - is_local = not is_cloud_connection(hass) + if webhook["local_only"] in (True, None): + is_local = not (is_cloud_connection(hass) or request.remote is None) + if is_local: if TYPE_CHECKING: - assert isinstance(request, Request) assert request.remote is not None try: @@ -273,6 +276,7 @@ async def websocket_handle( method=msg["method"], query_string=msg["query"], mock_source=f"{DOMAIN}/ws", + remote=connection.remote, ) response = await async_handle_webhook(hass, msg["webhook_id"], request) diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index ae4844cd69a..dfb16e16e95 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -78,6 +78,7 @@ class AuthPhase: self._send_message, self._request[KEY_HASS_USER], refresh_token=None, + remote=self._request.remote, ) await self._send_bytes_text(AUTH_OK_MESSAGE) self._logger.debug("Auth OK (unix socket)") @@ -111,6 +112,7 @@ class AuthPhase: self._send_message, refresh_token.user, refresh_token, + remote=self._request.remote, ) conn.subscriptions["auth"] = ( self._hass.auth.async_register_revoke_token_callback( diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index dad8ebe5686..cf67e70c2d6 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -47,6 +47,7 @@ class ActiveConnection: "last_id", "logger", "refresh_token_id", + "remote", "send_message", "subscriptions", "supported_features", @@ -60,6 +61,7 @@ class ActiveConnection: send_message: Callable[[bytes | str | dict[str, Any]], None], user: User, refresh_token: RefreshToken | None, + remote: str | None, ) -> None: """Initialize an active connection.""" self.logger = logger @@ -67,6 +69,7 @@ class ActiveConnection: self.send_message = send_message self.user = user self.refresh_token_id = refresh_token.id if refresh_token else None + self.remote = remote self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 self.can_coalesce = False diff --git a/homeassistant/util/aiohttp.py b/homeassistant/util/aiohttp.py index 888da368053..b65f64d6ecb 100644 --- a/homeassistant/util/aiohttp.py +++ b/homeassistant/util/aiohttp.py @@ -58,7 +58,7 @@ _MOCK_PAYLOAD_WRITER = MockPayloadWriter() class MockRequest: """Mock an aiohttp request.""" - mock_source: str | None = None + mock_source: str def __init__( self, @@ -69,6 +69,7 @@ class MockRequest: headers: dict[str, str] | None = None, query_string: str | None = None, url: str = "", + remote: str | None = None, ) -> None: """Initialize a request.""" self.method = method @@ -81,6 +82,7 @@ class MockRequest: self._content = content self.mock_source = mock_source self._payload_writer = _MOCK_PAYLOAD_WRITER + self.remote = remote async def _prepare_hook(self, response: Any) -> None: """Prepare hook.""" diff --git a/tests/components/cloud/test_client.py b/tests/components/cloud/test_client.py index 283e2ff39f1..8b4d29155fa 100644 --- a/tests/components/cloud/test_client.py +++ b/tests/components/cloud/test_client.py @@ -316,6 +316,51 @@ async def test_webhook_msg( assert '{"nonexisting": "payload"}' in caplog.text +async def test_webhook_msg_local_only(hass: HomeAssistant) -> None: + """Test a cloudhook for a local_only webhook does not fire the handler.""" + with patch("hass_nabucasa.Cloud.initialize"): + setup = await async_setup_component(hass, "cloud", {"cloud": {}}) + assert setup + cloud = hass.data[DATA_CLOUD] + + await cloud.client.prefs.async_initialize() + await cloud.client.prefs.async_update( + cloudhooks={ + "mock-webhook-id": { + "webhook_id": "mock-webhook-id", + "cloudhook_id": "mock-cloud-id", + }, + } + ) + + received = [] + + async def handler( + hass: HomeAssistant, webhook_id: str, request: web.Request + ) -> web.Response: + """Handle a webhook.""" + received.append(request) + return web.json_response({"from": "handler"}) + + webhook.async_register( + hass, "test", "Test", "mock-webhook-id", handler, local_only=True + ) + + response = await cloud.client.async_webhook_message( + { + "cloudhook_id": "mock-cloud-id", + "body": '{"hello": "world"}', + "headers": {"content-type": CONTENT_TYPE_JSON}, + "method": "POST", + "query": None, + } + ) + + assert response["status"] == 200 + # Handler not called because cloudhooks are not considered local + assert len(received) == 0 + + @pytest.mark.usefixtures("mock_cloud_setup", "mock_cloud_login") async def test_google_config_expose_entity( hass: HomeAssistant, diff --git a/tests/components/webhook/test_init.py b/tests/components/webhook/test_init.py index 20fe5024962..6300d640157 100644 --- a/tests/components/webhook/test_init.py +++ b/tests/components/webhook/test_init.py @@ -9,10 +9,13 @@ from aiohttp.test_utils import TestClient import pytest from homeassistant.components import webhook +from homeassistant.components.websocket_api import auth, http from homeassistant.core import HomeAssistant from homeassistant.core_config import async_process_ha_core_config from homeassistant.setup import async_setup_component +from homeassistant.util.aiohttp import MockRequest +from tests.test_util import mock_real_ip from tests.typing import ClientSessionGenerator, WebSocketGenerator @@ -267,6 +270,45 @@ async def test_webhook_local_only(hass: HomeAssistant, mock_client) -> None: assert len(hooks) == 1 +@pytest.mark.parametrize( + ("remote", "expected_calls"), + [ + (None, 0), + ("123.123.123.123", 0), + ("not-an-ip", 0), + ("192.168.1.50", 1), + ], +) +async def test_webhook_local_only_mock_request( + hass: HomeAssistant, remote: str | None, expected_calls: int +) -> None: + """Test local_only webhooks for MockRequests with various remote values.""" + await async_setup_component(hass, "webhook", {}) + + hooks = [] + webhook_id = webhook.async_generate_id() + + async def handle(hass: HomeAssistant, webhook_id: str, request: web.Request): + """Handle webhook.""" + hooks.append((hass, webhook_id, await request.text())) + + webhook.async_register( + hass, "test", "Test hook", webhook_id, handle, local_only=True + ) + + request = MockRequest( + content=b'{"data": true}', + headers={"Content-Type": "application/json"}, + method="POST", + query_string="", + mock_source="test", + remote=remote, + ) + resp = await webhook.async_handle_webhook(hass, webhook_id, request) + assert resp.status == HTTPStatus.OK + assert len(hooks) == expected_calls + + @pytest.mark.usefixtures("enable_custom_integrations") async def test_listing_webhook( hass: HomeAssistant, @@ -356,6 +398,8 @@ async def test_ws_webhook( assert received[0].headers["content-type"] == "application/json" assert received[0].query == {"a": "2"} assert await received[0].json() == {"hello": "world"} + # The MockRequest is created with the websocket connection's remote IP + assert received[0].remote is not None # Non existing webhook caplog.clear() @@ -383,3 +427,68 @@ async def test_ws_webhook( in caplog.text ) assert '{"nonexisting": "payload"}' in caplog.text + + +@pytest.mark.parametrize( + ("remote_ip", "expected_calls"), + [ + ("192.168.1.50", 1), + ("123.123.123.123", 0), + ], +) +async def test_ws_webhook_local_only( + hass: HomeAssistant, + hass_client_no_auth: ClientSessionGenerator, + hass_access_token: str, + remote_ip: str, + expected_calls: int, +) -> None: + """Test a local_only webhook over the websocket connection.""" + assert await async_setup_component(hass, "webhook", {}) + assert await async_setup_component(hass, "websocket_api", {}) + await hass.async_block_till_done() + + received = [] + + async def handler( + hass: HomeAssistant, webhook_id: str, request: web.Request + ) -> web.Response: + """Handle a webhook.""" + received.append(request) + return web.json_response({"from": "handler"}) + + webhook.async_register( + hass, "test", "Test", "mock-webhook-id", handler, local_only=True + ) + + set_mock_ip = mock_real_ip(hass.http.app) + set_mock_ip(remote_ip) + + client = await hass_client_no_auth() + + async with client.ws_connect(http.URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg["type"] == auth.TYPE_AUTH_REQUIRED + + await ws.send_json({"type": auth.TYPE_AUTH, "access_token": hass_access_token}) + auth_msg = await ws.receive_json() + assert auth_msg["type"] == auth.TYPE_AUTH_OK + + await ws.send_json( + { + "id": 5, + "type": "webhook/handle", + "webhook_id": "mock-webhook-id", + "method": "POST", + "headers": {"Content-Type": "application/json"}, + "body": '{"hello": "world"}', + "query": "", + } + ) + result = await ws.receive_json() + + assert result["success"], result + assert result["result"]["status"] == HTTPStatus.OK + assert len(received) == expected_calls + if expected_calls: + assert received[0].remote == remote_ip diff --git a/tests/components/websocket_api/test_connection.py b/tests/components/websocket_api/test_connection.py index 343575e5b4a..74950f43ba7 100644 --- a/tests/components/websocket_api/test_connection.py +++ b/tests/components/websocket_api/test_connection.py @@ -100,7 +100,12 @@ async def test_exception_handling( ) as current_request: current_request.get.return_value = mocked_request conn = websocket_api.ActiveConnection( - logging.getLogger(__name__), hass, send_messages.append, user, refresh_token + logging.getLogger(__name__), + hass, + send_messages.append, + user, + refresh_token, + remote="127.0.0.42", ) conn.async_handle_exception({"id": 5}, exc) @@ -113,7 +118,7 @@ async def test_exception_handling( async def test_binary_handler_registration() -> None: """Test binary handler registration.""" connection = websocket_api.ActiveConnection( - None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock() + None, Mock(data={websocket_api.DOMAIN: None}), None, None, Mock(), remote=None ) # One filler to align indexes with prefix numbers