1
0
mirror of https://github.com/home-assistant/core.git synced 2026-05-08 17:49:37 +01:00

Add custom headers support to downloader (#160541)

Signed-off-by: Pierre PÉRONNET <pierre.peronnet@gmail.com>
Co-authored-by: Ariel Ebersberger <ariel@ebersberger.io>
This commit is contained in:
Pierre PÉRONNET
2026-02-19 12:29:30 +01:00
committed by GitHub
parent 4615b4d104
commit b194741a13
5 changed files with 92 additions and 3 deletions
+1 -2
View File
@@ -11,8 +11,7 @@ ATTR_FILENAME = "filename"
ATTR_SUBDIR = "subdir"
ATTR_URL = "url"
ATTR_OVERWRITE = "overwrite"
CONF_DOWNLOAD_DIR = "download_dir"
ATTR_HEADERS = "headers"
DOWNLOAD_FAILED_EVENT = "download_failed"
DOWNLOAD_COMPLETED_EVENT = "download_completed"
@@ -19,6 +19,7 @@ from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
from .const import (
_LOGGER,
ATTR_FILENAME,
ATTR_HEADERS,
ATTR_OVERWRITE,
ATTR_SUBDIR,
ATTR_URL,
@@ -39,6 +40,7 @@ def download_file(service: ServiceCall) -> None:
subdir: str | None = service.data.get(ATTR_SUBDIR)
target_filename: str | None = service.data.get(ATTR_FILENAME)
overwrite: bool = service.data[ATTR_OVERWRITE]
headers: dict[str, str] = service.data[ATTR_HEADERS]
if subdir:
# Check the path
@@ -62,7 +64,7 @@ def download_file(service: ServiceCall) -> None:
final_path = None
filename = target_filename
try:
req = requests.get(url, stream=True, timeout=10)
req = requests.get(url, stream=True, headers=headers, timeout=10)
if req.status_code != HTTPStatus.OK:
_LOGGER.warning(
@@ -162,6 +164,9 @@ def async_setup_services(hass: HomeAssistant) -> None:
vol.Optional(ATTR_SUBDIR): cv.string,
vol.Required(ATTR_URL): cv.url,
vol.Optional(ATTR_OVERWRITE, default=False): cv.boolean,
vol.Optional(ATTR_HEADERS, default=dict): vol.Schema(
{cv.string: cv.string}
),
}
),
)
@@ -17,3 +17,9 @@ download_file:
default: false
selector:
boolean:
headers:
default: {}
example:
Accept: application/json
selector:
object:
@@ -28,6 +28,10 @@
"description": "Custom name for the downloaded file.",
"name": "Filename"
},
"headers": {
"description": "Additional custom HTTP headers.",
"name": "Headers"
},
"overwrite": {
"description": "Overwrite file if it exists.",
"name": "Overwrite"
@@ -4,6 +4,8 @@ import asyncio
from contextlib import AbstractContextManager, nullcontext as does_not_raise
import pytest
from requests_mock import Mocker
import voluptuous as vol
from homeassistant.components.downloader.const import DOMAIN
from homeassistant.core import HomeAssistant
@@ -52,3 +54,76 @@ async def test_download_invalid_subdir(
with expected_result:
await call_service()
@pytest.mark.usefixtures("setup_integration")
async def test_download_headers_passed_through(
hass: HomeAssistant,
requests_mock: Mocker,
download_completed: asyncio.Event,
download_url: str,
) -> None:
"""Test that custom headers are passed to the HTTP request."""
await hass.services.async_call(
DOMAIN,
"download_file",
{
"url": download_url,
"headers": {"Authorization": "Bearer token123", "X-Custom": "value"},
},
blocking=True,
)
await download_completed.wait()
assert requests_mock.last_request.headers["Authorization"] == "Bearer token123"
assert requests_mock.last_request.headers["X-Custom"] == "value"
@pytest.mark.usefixtures("setup_integration")
@pytest.mark.parametrize(
("headers", "expected_result"),
[
(1, pytest.raises(vol.error.Invalid)), # Not a dictionary
({"Accept": "application/json"}, does_not_raise()),
({123: 456.789}, does_not_raise()), # Convert numbers to strings
(
{"Accept": ["application/json"]},
pytest.raises(vol.error.MultipleInvalid),
), # Value is not a string
({1: None}, pytest.raises(vol.error.MultipleInvalid)), # Value is None
(
{None: "application/json"},
pytest.raises(vol.error.MultipleInvalid),
), # Key is None
],
)
async def test_download_headers_schema(
hass: HomeAssistant,
download_completed: asyncio.Event,
download_failed: asyncio.Event,
download_url: str,
headers: dict[str, str],
expected_result: AbstractContextManager,
) -> None:
"""Test service with headers."""
async def call_service() -> None:
"""Call the download service."""
completed = hass.async_create_task(download_completed.wait())
failed = hass.async_create_task(download_failed.wait())
await hass.services.async_call(
DOMAIN,
"download_file",
{
"url": download_url,
"headers": headers,
"subdir": "test",
"filename": "file.txt",
"overwrite": True,
},
blocking=True,
)
await asyncio.wait((completed, failed), return_when=asyncio.FIRST_COMPLETED)
with expected_result:
await call_service()