diff --git a/homeassistant/helpers/aiohttp_client.py b/homeassistant/helpers/aiohttp_client.py index f7ccdbeae6e..cf40441bf5f 100644 --- a/homeassistant/helpers/aiohttp_client.py +++ b/homeassistant/helpers/aiohttp_client.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from contextlib import suppress +from functools import lru_cache +from ipaddress import ip_address import socket from ssl import SSLContext import sys @@ -12,10 +14,11 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Self import aiohttp -from aiohttp import web +from aiohttp import ClientMiddlewareType, hdrs, web from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout from aiohttp_asyncmdnsresolver.api import AsyncDualMDNSResolver +from yarl import URL from homeassistant import config_entries from homeassistant.components import zeroconf @@ -25,6 +28,7 @@ from homeassistant.loader import bind_hass from homeassistant.util import ssl as ssl_util from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import json_loads +from homeassistant.util.network import is_loopback from .frame import warn_use from .json import json_dumps @@ -49,6 +53,92 @@ SERVER_SOFTWARE = ( WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session" +_LOCALHOST = "localhost" +_TRAILING_LOCAL_HOST = f".{_LOCALHOST}" + + +class SSRFRedirectError(aiohttp.ClientError): + """SSRF redirect protection. + + Raised when a redirect targets a blocked address (loopback or unspecified). + """ + + +async def _ssrf_redirect_middleware( + request: aiohttp.ClientRequest, + handler: aiohttp.ClientHandlerType, +) -> aiohttp.ClientResponse: + """Block redirects from non-loopback origins to loopback targets.""" + resp = await handler(request) + + # Return early if not a redirect or already loopback to allow loopback origins + connector = request.session.connector + if not (300 <= resp.status < 400) or await _async_is_blocked_host( + request.url.host, connector + ): + return resp + + location = resp.headers.get(hdrs.LOCATION, "") + if not location: + return resp + + redirect_url = URL(location) + if not redirect_url.is_absolute(): + # Relative redirects stay on the same host - always safe + return resp + + host = redirect_url.host + if await _async_is_blocked_host(host, connector): + resp.close() + raise SSRFRedirectError( + f"Redirect from {request.url.host} to a blocked address" + f" is not allowed: {host}" + ) + + return resp + + +@lru_cache +def _is_ssrf_address(address: str) -> bool: + """Check if an IP address is a potential SSRF target. + + Returns True for loopback and unspecified addresses. + """ + ip = ip_address(address) + return is_loopback(ip) or ip.is_unspecified + + +async def _async_is_blocked_host( + host: str | None, connector: aiohttp.BaseConnector | None +) -> bool: + """Check if a host is blocked by hostname or by resolved IP. + + First does a fast sync check on the hostname string, then resolves + the hostname via the connector and checks each resolved IP address. + """ + if not host: + return False + + # Strip FQDN trailing dot (RFC 1035) since yarl preserves it, + # preventing an attacker from bypassing the check with "localhost." + stripped_host = host.strip().removesuffix(".") + if stripped_host == _LOCALHOST or stripped_host.endswith(_TRAILING_LOCAL_HOST): + return True + + with suppress(ValueError): + return _is_ssrf_address(host) + + if not isinstance(connector, HomeAssistantTCPConnector): + return False + + try: + results = await connector.async_resolve_host(host) + except Exception: # noqa: BLE001 + return False + + return any(_is_ssrf_address(result["host"]) for result in results) + + # # The default connection limit of 100 meant that you could only have # 100 concurrent connections. @@ -191,10 +281,16 @@ def _async_create_clientsession( **kwargs: Any, ) -> aiohttp.ClientSession: """Create a new ClientSession with kwargs, i.e. for cookies.""" + middlewares: Sequence[ClientMiddlewareType] = ( + _ssrf_redirect_middleware, + *kwargs.pop("middlewares", ()), + ) + clientsession = aiohttp.ClientSession( connector=_async_get_connector(hass, verify_ssl, family, ssl_cipher), json_serialize=json_dumps, response_class=HassClientResponse, + middlewares=middlewares, **kwargs, ) # Prevent packages accidentally overriding our default headers @@ -343,6 +439,10 @@ class HomeAssistantTCPConnector(aiohttp.TCPConnector): # abort transport after 60 seconds (cleanup broken connections) _cleanup_closed_period = 60.0 + async def async_resolve_host(self, host: str) -> list[aiohttp.abc.ResolveResult]: + """Resolve a host to a list of addresses.""" + return await self._resolve_host(host, 0) + @callback def _async_get_connector( diff --git a/tests/helpers/test_aiohttp_client.py b/tests/helpers/test_aiohttp_client.py index b75850a3626..0862d3c1e76 100644 --- a/tests/helpers/test_aiohttp_client.py +++ b/tests/helpers/test_aiohttp_client.py @@ -1,10 +1,12 @@ """Test the aiohttp client helper.""" +from collections.abc import AsyncGenerator import socket from unittest.mock import Mock, patch import aiohttp -from aiohttp.test_utils import TestClient +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer import pytest from homeassistant.components.mjpeg import ( @@ -440,3 +442,179 @@ async def test_connector_no_verify_uses_http11_alpn(hass: HomeAssistant) -> None mock_client_context_no_verify.assert_called_once_with( SSLCipherList.PYTHON_DEFAULT, ssl_util.SSL_ALPN_HTTP11 ) + + +@pytest.fixture +async def redirect_server() -> AsyncGenerator[TestServer]: + """Start a test server that redirects based on query parameters.""" + + async def handle_redirect(request: web.Request) -> web.Response: + """Redirect to the URL specified in the 'to' query parameter.""" + location = request.query["to"] + return web.Response(status=307, headers={"Location": location}) + + async def handle_ok(request: web.Request) -> web.Response: + """Return a 200 OK response.""" + return web.Response(text="ok") + + app = web.Application() + app.router.add_get("/redirect", handle_redirect) + app.router.add_get("/ok", handle_ok) + + async def _mock_resolve_host( + self: aiohttp.TCPConnector, + host: str, + port: int, + traces: object = None, + ) -> list[dict[str, object]]: + return [ + { + "hostname": host, + "host": "127.0.0.1", + "port": port, + "family": socket.AF_INET, + "proto": 6, + "flags": 0, + } + ] + + server = TestServer(app) + await server.start_server() + # Route all TCP connections to the local test server + # This allows us to test redirect behavior of external URLs + # without actually making network requests + with patch.object(aiohttp.TCPConnector, "_resolve_host", _mock_resolve_host): + yield server + await server.close() + + +def _resolve_result(host: str, addr: str) -> list[dict[str, object]]: + """Build a mock DNS resolve result for the SSRF check.""" + return [ + { + "hostname": host, + "host": addr, + "port": 0, + "family": socket.AF_INET, + "proto": 6, + "flags": 0, + } + ] + + +@pytest.mark.usefixtures("socket_enabled") +async def test_redirect_loopback_to_loopback_allowed( + hass: HomeAssistant, redirect_server: TestServer +) -> None: + """Test that redirects from loopback to loopback are allowed.""" + session = client.async_get_clientsession(hass) + target = str(redirect_server.make_url("/ok")) + redirect_url = redirect_server.make_url(f"/redirect?to={target}") + + # Both origin and target are on 127.0.0.1 — should be allowed + resp = await session.get(redirect_url) + assert resp.status == 200 + + +@pytest.mark.usefixtures("socket_enabled") +async def test_redirect_relative_url_allowed( + hass: HomeAssistant, redirect_server: TestServer +) -> None: + """Test that relative redirects are allowed (they stay on the same host).""" + session = client.async_create_clientsession(hass) + server_port = redirect_server.port + + # Redirect from an external origin to a relative path + redirect_url = f"http://external.example.com:{server_port}/redirect?to=/ok" + + async def mock_async_resolve_host(host: str) -> list[dict[str, object]]: + """Return public IPs for all hosts.""" + return _resolve_result(host, "93.184.216.34") + + connector = session.connector + with patch.object(connector, "async_resolve_host", mock_async_resolve_host): + resp = await session.get(redirect_url) + assert resp.status == 200 + + +@pytest.mark.usefixtures("socket_enabled") +@pytest.mark.parametrize( + "target", + [ + "http://other.example.com:{port}/ok", + "http://safe.example.com:{port}/ok", + "http://notlocalhost:{port}/ok", + ], +) +async def test_redirect_to_non_loopback_allowed( + hass: HomeAssistant, redirect_server: TestServer, target: str +) -> None: + """Test that redirects to non-loopback addresses are allowed.""" + session = client.async_create_clientsession(hass) + server_port = redirect_server.port + + location = target.format(port=server_port) + redirect_url = f"http://external.example.com:{server_port}/redirect?to={location}" + + async def mock_async_resolve_host(host: str) -> list[dict[str, object]]: + """Return public IPs for all hosts.""" + return _resolve_result(host, "93.184.216.34") + + connector = session.connector + with patch.object(connector, "async_resolve_host", mock_async_resolve_host): + resp = await session.get(redirect_url) + assert resp.status == 200 + + +@pytest.mark.usefixtures("socket_enabled") +@pytest.mark.parametrize( + ("location", "target_resolved_addr"), + [ + # Loopback IPs and hostnames — blocked before DNS resolution + ("http://127.0.0.1/evil", None), + ("http://[::1]/evil", None), + ("http://localhost/evil", None), + ("http://localhost./evil", None), + ("http://example.localhost/evil", None), + ("http://example.localhost./evil", None), + ("http://app.localhost/evil", None), + ("http://sub.domain.localhost/evil", None), + # Benign hostnames resolving to blocked IPs — blocked after DNS + ("http://evil.example.com:{port}/steal", "127.0.0.1"), + ("http://evil.example.com:{port}/steal", "127.0.0.2"), + ("http://evil.example.com:{port}/steal", "::1"), + ("http://evil.example.com:{port}/steal", "0.0.0.0"), + ("http://evil.example.com:{port}/steal", "::"), + ], +) +async def test_redirect_to_blocked_address( + hass: HomeAssistant, + redirect_server: TestServer, + location: str, + target_resolved_addr: str | None, +) -> None: + """Test that redirects to blocked addresses are blocked. + + Covers both cases: targets blocked by hostname/IP (before DNS) and + targets blocked after DNS resolution reveals a loopback/unspecified IP. + """ + session = client.async_create_clientsession(hass) + server_port = redirect_server.port + + target = location.format(port=server_port) + redirect_url = f"http://external.example.com:{server_port}/redirect?to={target}" + + async def mock_async_resolve_host(host: str) -> list[dict[str, object]]: + """Return public IP for origin, optional blocked IP for target.""" + if host == "external.example.com": + return _resolve_result(host, "93.184.216.34") + if target_resolved_addr is not None: + return _resolve_result(host, target_resolved_addr) + return [] + + connector = session.connector + with ( + patch.object(connector, "async_resolve_host", mock_async_resolve_host), + pytest.raises(client.SSRFRedirectError), + ): + await session.get(redirect_url)