diff --git a/homeassistant/components/downloader/const.py b/homeassistant/components/downloader/const.py index 14160e4cd5d..69c606a1c09 100644 --- a/homeassistant/components/downloader/const.py +++ b/homeassistant/components/downloader/const.py @@ -11,8 +11,7 @@ ATTR_FILENAME = "filename" ATTR_SUBDIR = "subdir" ATTR_URL = "url" ATTR_OVERWRITE = "overwrite" - -CONF_DOWNLOAD_DIR = "download_dir" +ATTR_HEADERS = "headers" DOWNLOAD_FAILED_EVENT = "download_failed" DOWNLOAD_COMPLETED_EVENT = "download_completed" diff --git a/homeassistant/components/downloader/services.py b/homeassistant/components/downloader/services.py index 0ccaee232d7..74b503bebda 100644 --- a/homeassistant/components/downloader/services.py +++ b/homeassistant/components/downloader/services.py @@ -19,6 +19,7 @@ from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path from .const import ( _LOGGER, ATTR_FILENAME, + ATTR_HEADERS, ATTR_OVERWRITE, ATTR_SUBDIR, ATTR_URL, @@ -39,6 +40,7 @@ def download_file(service: ServiceCall) -> None: subdir: str | None = service.data.get(ATTR_SUBDIR) target_filename: str | None = service.data.get(ATTR_FILENAME) overwrite: bool = service.data[ATTR_OVERWRITE] + headers: dict[str, str] = service.data[ATTR_HEADERS] if subdir: # Check the path @@ -62,7 +64,7 @@ def download_file(service: ServiceCall) -> None: final_path = None filename = target_filename try: - req = requests.get(url, stream=True, timeout=10) + req = requests.get(url, stream=True, headers=headers, timeout=10) if req.status_code != HTTPStatus.OK: _LOGGER.warning( @@ -162,6 +164,9 @@ def async_setup_services(hass: HomeAssistant) -> None: vol.Optional(ATTR_SUBDIR): cv.string, vol.Required(ATTR_URL): cv.url, vol.Optional(ATTR_OVERWRITE, default=False): cv.boolean, + vol.Optional(ATTR_HEADERS, default=dict): vol.Schema( + {cv.string: cv.string} + ), } ), ) diff --git a/homeassistant/components/downloader/services.yaml b/homeassistant/components/downloader/services.yaml index 54d06db5627..24f9f56ec11 100644 --- a/homeassistant/components/downloader/services.yaml +++ b/homeassistant/components/downloader/services.yaml @@ -17,3 +17,9 @@ download_file: default: false selector: boolean: + headers: + default: {} + example: + Accept: application/json + selector: + object: diff --git a/homeassistant/components/downloader/strings.json b/homeassistant/components/downloader/strings.json index 2c1e0352c4e..e18654212a8 100644 --- a/homeassistant/components/downloader/strings.json +++ b/homeassistant/components/downloader/strings.json @@ -28,6 +28,10 @@ "description": "Custom name for the downloaded file.", "name": "Filename" }, + "headers": { + "description": "Additional custom HTTP headers.", + "name": "Headers" + }, "overwrite": { "description": "Overwrite file if it exists.", "name": "Overwrite" diff --git a/tests/components/downloader/test_services.py b/tests/components/downloader/test_services.py index fbdc088021a..6fa75ab95da 100644 --- a/tests/components/downloader/test_services.py +++ b/tests/components/downloader/test_services.py @@ -4,6 +4,8 @@ import asyncio from contextlib import AbstractContextManager, nullcontext as does_not_raise import pytest +from requests_mock import Mocker +import voluptuous as vol from homeassistant.components.downloader.const import DOMAIN from homeassistant.core import HomeAssistant @@ -52,3 +54,76 @@ async def test_download_invalid_subdir( with expected_result: await call_service() + + +@pytest.mark.usefixtures("setup_integration") +async def test_download_headers_passed_through( + hass: HomeAssistant, + requests_mock: Mocker, + download_completed: asyncio.Event, + download_url: str, +) -> None: + """Test that custom headers are passed to the HTTP request.""" + await hass.services.async_call( + DOMAIN, + "download_file", + { + "url": download_url, + "headers": {"Authorization": "Bearer token123", "X-Custom": "value"}, + }, + blocking=True, + ) + await download_completed.wait() + + assert requests_mock.last_request.headers["Authorization"] == "Bearer token123" + assert requests_mock.last_request.headers["X-Custom"] == "value" + + +@pytest.mark.usefixtures("setup_integration") +@pytest.mark.parametrize( + ("headers", "expected_result"), + [ + (1, pytest.raises(vol.error.Invalid)), # Not a dictionary + ({"Accept": "application/json"}, does_not_raise()), + ({123: 456.789}, does_not_raise()), # Convert numbers to strings + ( + {"Accept": ["application/json"]}, + pytest.raises(vol.error.MultipleInvalid), + ), # Value is not a string + ({1: None}, pytest.raises(vol.error.MultipleInvalid)), # Value is None + ( + {None: "application/json"}, + pytest.raises(vol.error.MultipleInvalid), + ), # Key is None + ], +) +async def test_download_headers_schema( + hass: HomeAssistant, + download_completed: asyncio.Event, + download_failed: asyncio.Event, + download_url: str, + headers: dict[str, str], + expected_result: AbstractContextManager, +) -> None: + """Test service with headers.""" + + async def call_service() -> None: + """Call the download service.""" + completed = hass.async_create_task(download_completed.wait()) + failed = hass.async_create_task(download_failed.wait()) + await hass.services.async_call( + DOMAIN, + "download_file", + { + "url": download_url, + "headers": headers, + "subdir": "test", + "filename": "file.txt", + "overwrite": True, + }, + blocking=True, + ) + await asyncio.wait((completed, failed), return_when=asyncio.FIRST_COMPLETED) + + with expected_result: + await call_service()