mirror of
https://github.com/home-assistant/core.git
synced 2026-02-14 23:28:42 +00:00
Block redirect to localhost (#162941)
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from contextlib import suppress
|
||||
from functools import lru_cache
|
||||
from ipaddress import ip_address
|
||||
import socket
|
||||
from ssl import SSLContext
|
||||
import sys
|
||||
@@ -12,10 +14,11 @@ from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, Self
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from aiohttp import ClientMiddlewareType, hdrs, web
|
||||
from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
|
||||
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
|
||||
from aiohttp_asyncmdnsresolver.api import AsyncDualMDNSResolver
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import zeroconf
|
||||
@@ -25,6 +28,7 @@ from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import ssl as ssl_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import json_loads
|
||||
from homeassistant.util.network import is_loopback
|
||||
|
||||
from .frame import warn_use
|
||||
from .json import json_dumps
|
||||
@@ -49,6 +53,92 @@ SERVER_SOFTWARE = (
|
||||
|
||||
WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session"
|
||||
|
||||
_LOCALHOST = "localhost"
|
||||
_TRAILING_LOCAL_HOST = f".{_LOCALHOST}"
|
||||
|
||||
|
||||
class SSRFRedirectError(aiohttp.ClientError):
|
||||
"""SSRF redirect protection.
|
||||
|
||||
Raised when a redirect targets a blocked address (loopback or unspecified).
|
||||
"""
|
||||
|
||||
|
||||
async def _ssrf_redirect_middleware(
|
||||
request: aiohttp.ClientRequest,
|
||||
handler: aiohttp.ClientHandlerType,
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""Block redirects from non-loopback origins to loopback targets."""
|
||||
resp = await handler(request)
|
||||
|
||||
# Return early if not a redirect or already loopback to allow loopback origins
|
||||
connector = request.session.connector
|
||||
if not (300 <= resp.status < 400) or await _async_is_blocked_host(
|
||||
request.url.host, connector
|
||||
):
|
||||
return resp
|
||||
|
||||
location = resp.headers.get(hdrs.LOCATION, "")
|
||||
if not location:
|
||||
return resp
|
||||
|
||||
redirect_url = URL(location)
|
||||
if not redirect_url.is_absolute():
|
||||
# Relative redirects stay on the same host - always safe
|
||||
return resp
|
||||
|
||||
host = redirect_url.host
|
||||
if await _async_is_blocked_host(host, connector):
|
||||
resp.close()
|
||||
raise SSRFRedirectError(
|
||||
f"Redirect from {request.url.host} to a blocked address"
|
||||
f" is not allowed: {host}"
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _is_ssrf_address(address: str) -> bool:
|
||||
"""Check if an IP address is a potential SSRF target.
|
||||
|
||||
Returns True for loopback and unspecified addresses.
|
||||
"""
|
||||
ip = ip_address(address)
|
||||
return is_loopback(ip) or ip.is_unspecified
|
||||
|
||||
|
||||
async def _async_is_blocked_host(
|
||||
host: str | None, connector: aiohttp.BaseConnector | None
|
||||
) -> bool:
|
||||
"""Check if a host is blocked by hostname or by resolved IP.
|
||||
|
||||
First does a fast sync check on the hostname string, then resolves
|
||||
the hostname via the connector and checks each resolved IP address.
|
||||
"""
|
||||
if not host:
|
||||
return False
|
||||
|
||||
# Strip FQDN trailing dot (RFC 1035) since yarl preserves it,
|
||||
# preventing an attacker from bypassing the check with "localhost."
|
||||
stripped_host = host.strip().removesuffix(".")
|
||||
if stripped_host == _LOCALHOST or stripped_host.endswith(_TRAILING_LOCAL_HOST):
|
||||
return True
|
||||
|
||||
with suppress(ValueError):
|
||||
return _is_ssrf_address(host)
|
||||
|
||||
if not isinstance(connector, HomeAssistantTCPConnector):
|
||||
return False
|
||||
|
||||
try:
|
||||
results = await connector.async_resolve_host(host)
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
return any(_is_ssrf_address(result["host"]) for result in results)
|
||||
|
||||
|
||||
#
|
||||
# The default connection limit of 100 meant that you could only have
|
||||
# 100 concurrent connections.
|
||||
@@ -191,10 +281,16 @@ def _async_create_clientsession(
|
||||
**kwargs: Any,
|
||||
) -> aiohttp.ClientSession:
|
||||
"""Create a new ClientSession with kwargs, i.e. for cookies."""
|
||||
middlewares: Sequence[ClientMiddlewareType] = (
|
||||
_ssrf_redirect_middleware,
|
||||
*kwargs.pop("middlewares", ()),
|
||||
)
|
||||
|
||||
clientsession = aiohttp.ClientSession(
|
||||
connector=_async_get_connector(hass, verify_ssl, family, ssl_cipher),
|
||||
json_serialize=json_dumps,
|
||||
response_class=HassClientResponse,
|
||||
middlewares=middlewares,
|
||||
**kwargs,
|
||||
)
|
||||
# Prevent packages accidentally overriding our default headers
|
||||
@@ -343,6 +439,10 @@ class HomeAssistantTCPConnector(aiohttp.TCPConnector):
|
||||
# abort transport after 60 seconds (cleanup broken connections)
|
||||
_cleanup_closed_period = 60.0
|
||||
|
||||
async def async_resolve_host(self, host: str) -> list[aiohttp.abc.ResolveResult]:
|
||||
"""Resolve a host to a list of addresses."""
|
||||
return await self._resolve_host(host, 0)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_get_connector(
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Test the aiohttp client helper."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import socket
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.test_utils import TestClient
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.mjpeg import (
|
||||
@@ -440,3 +442,179 @@ async def test_connector_no_verify_uses_http11_alpn(hass: HomeAssistant) -> None
|
||||
mock_client_context_no_verify.assert_called_once_with(
|
||||
SSLCipherList.PYTHON_DEFAULT, ssl_util.SSL_ALPN_HTTP11
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def redirect_server() -> AsyncGenerator[TestServer]:
|
||||
"""Start a test server that redirects based on query parameters."""
|
||||
|
||||
async def handle_redirect(request: web.Request) -> web.Response:
|
||||
"""Redirect to the URL specified in the 'to' query parameter."""
|
||||
location = request.query["to"]
|
||||
return web.Response(status=307, headers={"Location": location})
|
||||
|
||||
async def handle_ok(request: web.Request) -> web.Response:
|
||||
"""Return a 200 OK response."""
|
||||
return web.Response(text="ok")
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/redirect", handle_redirect)
|
||||
app.router.add_get("/ok", handle_ok)
|
||||
|
||||
async def _mock_resolve_host(
|
||||
self: aiohttp.TCPConnector,
|
||||
host: str,
|
||||
port: int,
|
||||
traces: object = None,
|
||||
) -> list[dict[str, object]]:
|
||||
return [
|
||||
{
|
||||
"hostname": host,
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"family": socket.AF_INET,
|
||||
"proto": 6,
|
||||
"flags": 0,
|
||||
}
|
||||
]
|
||||
|
||||
server = TestServer(app)
|
||||
await server.start_server()
|
||||
# Route all TCP connections to the local test server
|
||||
# This allows us to test redirect behavior of external URLs
|
||||
# without actually making network requests
|
||||
with patch.object(aiohttp.TCPConnector, "_resolve_host", _mock_resolve_host):
|
||||
yield server
|
||||
await server.close()
|
||||
|
||||
|
||||
def _resolve_result(host: str, addr: str) -> list[dict[str, object]]:
|
||||
"""Build a mock DNS resolve result for the SSRF check."""
|
||||
return [
|
||||
{
|
||||
"hostname": host,
|
||||
"host": addr,
|
||||
"port": 0,
|
||||
"family": socket.AF_INET,
|
||||
"proto": 6,
|
||||
"flags": 0,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_redirect_loopback_to_loopback_allowed(
|
||||
hass: HomeAssistant, redirect_server: TestServer
|
||||
) -> None:
|
||||
"""Test that redirects from loopback to loopback are allowed."""
|
||||
session = client.async_get_clientsession(hass)
|
||||
target = str(redirect_server.make_url("/ok"))
|
||||
redirect_url = redirect_server.make_url(f"/redirect?to={target}")
|
||||
|
||||
# Both origin and target are on 127.0.0.1 — should be allowed
|
||||
resp = await session.get(redirect_url)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_redirect_relative_url_allowed(
|
||||
hass: HomeAssistant, redirect_server: TestServer
|
||||
) -> None:
|
||||
"""Test that relative redirects are allowed (they stay on the same host)."""
|
||||
session = client.async_create_clientsession(hass)
|
||||
server_port = redirect_server.port
|
||||
|
||||
# Redirect from an external origin to a relative path
|
||||
redirect_url = f"http://external.example.com:{server_port}/redirect?to=/ok"
|
||||
|
||||
async def mock_async_resolve_host(host: str) -> list[dict[str, object]]:
|
||||
"""Return public IPs for all hosts."""
|
||||
return _resolve_result(host, "93.184.216.34")
|
||||
|
||||
connector = session.connector
|
||||
with patch.object(connector, "async_resolve_host", mock_async_resolve_host):
|
||||
resp = await session.get(redirect_url)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
@pytest.mark.parametrize(
|
||||
"target",
|
||||
[
|
||||
"http://other.example.com:{port}/ok",
|
||||
"http://safe.example.com:{port}/ok",
|
||||
"http://notlocalhost:{port}/ok",
|
||||
],
|
||||
)
|
||||
async def test_redirect_to_non_loopback_allowed(
|
||||
hass: HomeAssistant, redirect_server: TestServer, target: str
|
||||
) -> None:
|
||||
"""Test that redirects to non-loopback addresses are allowed."""
|
||||
session = client.async_create_clientsession(hass)
|
||||
server_port = redirect_server.port
|
||||
|
||||
location = target.format(port=server_port)
|
||||
redirect_url = f"http://external.example.com:{server_port}/redirect?to={location}"
|
||||
|
||||
async def mock_async_resolve_host(host: str) -> list[dict[str, object]]:
|
||||
"""Return public IPs for all hosts."""
|
||||
return _resolve_result(host, "93.184.216.34")
|
||||
|
||||
connector = session.connector
|
||||
with patch.object(connector, "async_resolve_host", mock_async_resolve_host):
|
||||
resp = await session.get(redirect_url)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
@pytest.mark.parametrize(
|
||||
("location", "target_resolved_addr"),
|
||||
[
|
||||
# Loopback IPs and hostnames — blocked before DNS resolution
|
||||
("http://127.0.0.1/evil", None),
|
||||
("http://[::1]/evil", None),
|
||||
("http://localhost/evil", None),
|
||||
("http://localhost./evil", None),
|
||||
("http://example.localhost/evil", None),
|
||||
("http://example.localhost./evil", None),
|
||||
("http://app.localhost/evil", None),
|
||||
("http://sub.domain.localhost/evil", None),
|
||||
# Benign hostnames resolving to blocked IPs — blocked after DNS
|
||||
("http://evil.example.com:{port}/steal", "127.0.0.1"),
|
||||
("http://evil.example.com:{port}/steal", "127.0.0.2"),
|
||||
("http://evil.example.com:{port}/steal", "::1"),
|
||||
("http://evil.example.com:{port}/steal", "0.0.0.0"),
|
||||
("http://evil.example.com:{port}/steal", "::"),
|
||||
],
|
||||
)
|
||||
async def test_redirect_to_blocked_address(
|
||||
hass: HomeAssistant,
|
||||
redirect_server: TestServer,
|
||||
location: str,
|
||||
target_resolved_addr: str | None,
|
||||
) -> None:
|
||||
"""Test that redirects to blocked addresses are blocked.
|
||||
|
||||
Covers both cases: targets blocked by hostname/IP (before DNS) and
|
||||
targets blocked after DNS resolution reveals a loopback/unspecified IP.
|
||||
"""
|
||||
session = client.async_create_clientsession(hass)
|
||||
server_port = redirect_server.port
|
||||
|
||||
target = location.format(port=server_port)
|
||||
redirect_url = f"http://external.example.com:{server_port}/redirect?to={target}"
|
||||
|
||||
async def mock_async_resolve_host(host: str) -> list[dict[str, object]]:
|
||||
"""Return public IP for origin, optional blocked IP for target."""
|
||||
if host == "external.example.com":
|
||||
return _resolve_result(host, "93.184.216.34")
|
||||
if target_resolved_addr is not None:
|
||||
return _resolve_result(host, target_resolved_addr)
|
||||
return []
|
||||
|
||||
connector = session.connector
|
||||
with (
|
||||
patch.object(connector, "async_resolve_host", mock_async_resolve_host),
|
||||
pytest.raises(client.SSRFRedirectError),
|
||||
):
|
||||
await session.get(redirect_url)
|
||||
|
||||
Reference in New Issue
Block a user