mirror of
https://github.com/home-assistant/core.git
synced 2026-05-08 17:49:37 +01:00
Add custom headers support to downloader (#160541)
Signed-off-by: Pierre PÉRONNET <pierre.peronnet@gmail.com> Co-authored-by: Ariel Ebersberger <ariel@ebersberger.io>
This commit is contained in:
@@ -11,8 +11,7 @@ ATTR_FILENAME = "filename"
|
||||
ATTR_SUBDIR = "subdir"
|
||||
ATTR_URL = "url"
|
||||
ATTR_OVERWRITE = "overwrite"
|
||||
|
||||
CONF_DOWNLOAD_DIR = "download_dir"
|
||||
ATTR_HEADERS = "headers"
|
||||
|
||||
DOWNLOAD_FAILED_EVENT = "download_failed"
|
||||
DOWNLOAD_COMPLETED_EVENT = "download_completed"
|
||||
|
||||
@@ -19,6 +19,7 @@ from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
|
||||
from .const import (
|
||||
_LOGGER,
|
||||
ATTR_FILENAME,
|
||||
ATTR_HEADERS,
|
||||
ATTR_OVERWRITE,
|
||||
ATTR_SUBDIR,
|
||||
ATTR_URL,
|
||||
@@ -39,6 +40,7 @@ def download_file(service: ServiceCall) -> None:
|
||||
subdir: str | None = service.data.get(ATTR_SUBDIR)
|
||||
target_filename: str | None = service.data.get(ATTR_FILENAME)
|
||||
overwrite: bool = service.data[ATTR_OVERWRITE]
|
||||
headers: dict[str, str] = service.data[ATTR_HEADERS]
|
||||
|
||||
if subdir:
|
||||
# Check the path
|
||||
@@ -62,7 +64,7 @@ def download_file(service: ServiceCall) -> None:
|
||||
final_path = None
|
||||
filename = target_filename
|
||||
try:
|
||||
req = requests.get(url, stream=True, timeout=10)
|
||||
req = requests.get(url, stream=True, headers=headers, timeout=10)
|
||||
|
||||
if req.status_code != HTTPStatus.OK:
|
||||
_LOGGER.warning(
|
||||
@@ -162,6 +164,9 @@ def async_setup_services(hass: HomeAssistant) -> None:
|
||||
vol.Optional(ATTR_SUBDIR): cv.string,
|
||||
vol.Required(ATTR_URL): cv.url,
|
||||
vol.Optional(ATTR_OVERWRITE, default=False): cv.boolean,
|
||||
vol.Optional(ATTR_HEADERS, default=dict): vol.Schema(
|
||||
{cv.string: cv.string}
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,3 +17,9 @@ download_file:
|
||||
default: false
|
||||
selector:
|
||||
boolean:
|
||||
headers:
|
||||
default: {}
|
||||
example:
|
||||
Accept: application/json
|
||||
selector:
|
||||
object:
|
||||
|
||||
@@ -28,6 +28,10 @@
|
||||
"description": "Custom name for the downloaded file.",
|
||||
"name": "Filename"
|
||||
},
|
||||
"headers": {
|
||||
"description": "Additional custom HTTP headers.",
|
||||
"name": "Headers"
|
||||
},
|
||||
"overwrite": {
|
||||
"description": "Overwrite file if it exists.",
|
||||
"name": "Overwrite"
|
||||
|
||||
@@ -4,6 +4,8 @@ import asyncio
|
||||
from contextlib import AbstractContextManager, nullcontext as does_not_raise
|
||||
|
||||
import pytest
|
||||
from requests_mock import Mocker
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.downloader.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -52,3 +54,76 @@ async def test_download_invalid_subdir(
|
||||
|
||||
with expected_result:
|
||||
await call_service()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_integration")
|
||||
async def test_download_headers_passed_through(
|
||||
hass: HomeAssistant,
|
||||
requests_mock: Mocker,
|
||||
download_completed: asyncio.Event,
|
||||
download_url: str,
|
||||
) -> None:
|
||||
"""Test that custom headers are passed to the HTTP request."""
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"download_file",
|
||||
{
|
||||
"url": download_url,
|
||||
"headers": {"Authorization": "Bearer token123", "X-Custom": "value"},
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await download_completed.wait()
|
||||
|
||||
assert requests_mock.last_request.headers["Authorization"] == "Bearer token123"
|
||||
assert requests_mock.last_request.headers["X-Custom"] == "value"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_integration")
|
||||
@pytest.mark.parametrize(
|
||||
("headers", "expected_result"),
|
||||
[
|
||||
(1, pytest.raises(vol.error.Invalid)), # Not a dictionary
|
||||
({"Accept": "application/json"}, does_not_raise()),
|
||||
({123: 456.789}, does_not_raise()), # Convert numbers to strings
|
||||
(
|
||||
{"Accept": ["application/json"]},
|
||||
pytest.raises(vol.error.MultipleInvalid),
|
||||
), # Value is not a string
|
||||
({1: None}, pytest.raises(vol.error.MultipleInvalid)), # Value is None
|
||||
(
|
||||
{None: "application/json"},
|
||||
pytest.raises(vol.error.MultipleInvalid),
|
||||
), # Key is None
|
||||
],
|
||||
)
|
||||
async def test_download_headers_schema(
|
||||
hass: HomeAssistant,
|
||||
download_completed: asyncio.Event,
|
||||
download_failed: asyncio.Event,
|
||||
download_url: str,
|
||||
headers: dict[str, str],
|
||||
expected_result: AbstractContextManager,
|
||||
) -> None:
|
||||
"""Test service with headers."""
|
||||
|
||||
async def call_service() -> None:
|
||||
"""Call the download service."""
|
||||
completed = hass.async_create_task(download_completed.wait())
|
||||
failed = hass.async_create_task(download_failed.wait())
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"download_file",
|
||||
{
|
||||
"url": download_url,
|
||||
"headers": headers,
|
||||
"subdir": "test",
|
||||
"filename": "file.txt",
|
||||
"overwrite": True,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await asyncio.wait((completed, failed), return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
with expected_result:
|
||||
await call_service()
|
||||
|
||||
Reference in New Issue
Block a user