
Add upload progress tracking to S3 integrations (#166325)

Author: Josef Zweck
Date: 2026-03-24 14:44:05 +01:00
Committed by: GitHub
Parent: f38f3626fb
Commit: 1e9c8ec32c
6 changed files with 230 additions and 11 deletions
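
The change is mechanical and repeated across all three agents: the surrounding async_upload_backup already has an on_progress callback in scope (its signature is unchanged in this diff), and the commit threads it into the multipart path, which previously ignored it. A minimal sketch of that shape, with simplified names; the OnProgressCallback alias and the 5 MiB constant are assumptions here, the real callback type comes from the backup integration:

from collections.abc import AsyncIterator, Callable, Coroutine
from typing import Any

# Assumed stand-in for the backup integration's callback type;
# it is invoked with a bytes_uploaded keyword argument.
OnProgressCallback = Callable[..., None]

# Assumed to match S3's documented 5 MiB minimum multipart part size.
MULTIPART_MIN_PART_SIZE_BYTES = 5 * 1024 * 1024


async def _upload_simple(
    open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
) -> None:
    """Placeholder for the single-PUT path (no per-part progress)."""


async def _upload_multipart(
    open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
    on_progress: OnProgressCallback,
) -> None:
    """Placeholder for the chunked path that reports progress per part."""


async def upload_backup(
    open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
    size: int,
    on_progress: OnProgressCallback,
) -> None:
    if size < MULTIPART_MIN_PART_SIZE_BYTES:
        await _upload_simple(open_stream)  # one request, progress jumps 0 -> size
    else:
        await _upload_multipart(open_stream, on_progress)  # incremental progress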

homeassistant/components/aws_s3/backup.py

@@ -147,7 +147,7 @@ class S3BackupAgent(BackupAgent):
         if backup.size < MULTIPART_MIN_PART_SIZE_BYTES:
             await self._upload_simple(tar_filename, open_stream)
         else:
-            await self._upload_multipart(tar_filename, open_stream)
+            await self._upload_multipart(tar_filename, open_stream, on_progress)
 
         # Upload the metadata file
         metadata_content = json.dumps(backup.as_dict())
@@ -188,11 +188,13 @@ class S3BackupAgent(BackupAgent):
         self,
         tar_filename: str,
         open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
-    ):
+        on_progress: OnProgressCallback,
+    ) -> None:
         """Upload a large file using multipart upload.
 
         :param tar_filename: The target filename for the backup.
         :param open_stream: A function returning an async iterator that yields bytes.
+        :param on_progress: A callback to report the number of uploaded bytes.
         """
         _LOGGER.debug("Starting multipart upload for %s", tar_filename)
         multipart_upload = await self._client.create_multipart_upload(
@@ -205,6 +207,7 @@ class S3BackupAgent(BackupAgent):
         part_number = 1
         buffer = bytearray()  # bytes buffer to store the data
         offset = 0  # start index of unread data inside buffer
+        bytes_uploaded = 0
 
         stream = await open_stream()
         async for chunk in stream:
@@ -233,6 +236,8 @@ class S3BackupAgent(BackupAgent):
                         Body=part_data.tobytes(),
                     )
                     parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+                    bytes_uploaded += len(part_data)
+                    on_progress(bytes_uploaded=bytes_uploaded)
                     part_number += 1
                 finally:
                     view.release()
@@ -261,6 +266,8 @@ class S3BackupAgent(BackupAgent):
                 Body=remaining_data.tobytes(),
             )
             parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+            bytes_uploaded += len(remaining_data)
+            on_progress(bytes_uploaded=bytes_uploaded)
 
         await cast(Any, self._client).complete_multipart_upload(
             Bucket=self._bucket,
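
For context, here is a self-contained re-creation of the buffering loop those hunks touch, with invented helper names and a fake upload: chunks accumulate in a bytearray, fixed-size parts are cut through a memoryview, and the callback receives the cumulative count after every part, including the trailing remainder, so the final report equals the total size.

import asyncio
from collections.abc import AsyncIterator


async def upload_multipart_sketch(
    stream: AsyncIterator[bytes],
    part_size: int,
    on_progress,
) -> list[int]:
    buffer = bytearray()  # bytes buffer to store the data
    offset = 0  # start index of unread data inside buffer
    bytes_uploaded = 0
    part_sizes: list[int] = []  # stand-in for the PartNumber/ETag bookkeeping

    async for chunk in stream:
        buffer.extend(chunk)
        # Cut as many full-size parts as the buffer currently holds.
        while len(buffer) - offset >= part_size:
            view = memoryview(buffer)[offset : offset + part_size]
            try:
                # A real agent would call client.upload_part(Body=view.tobytes(), ...)
                part_sizes.append(len(view))
                bytes_uploaded += len(view)
                on_progress(bytes_uploaded=bytes_uploaded)
            finally:
                view.release()  # allow the bytearray to be resized again
            offset += part_size
        if offset:
            del buffer[:offset]  # drop consumed bytes, keep the unread tail
            offset = 0

    if buffer:  # trailing part, allowed to be smaller than part_size
        part_sizes.append(len(buffer))
        bytes_uploaded += len(buffer)
        on_progress(bytes_uploaded=bytes_uploaded)
    return part_sizes


async def _demo() -> None:
    async def chunks() -> AsyncIterator[bytes]:
        for _ in range(3):
            yield b"x" * 700

    sizes = await upload_multipart_sketch(
        chunks(), part_size=512, on_progress=lambda bytes_uploaded: print(bytes_uploaded)
    )
    assert sizes == [512, 512, 512, 52] and sum(sizes) == 2100


asyncio.run(_demo())

Reporting cumulative totals rather than per-part deltas is what lets the new tests below assert a strictly increasing sequence of uploaded_bytes values.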

homeassistant/components/cloudflare_r2/backup.py

@@ -144,7 +144,7 @@ class R2BackupAgent(BackupAgent):
         if backup.size < MULTIPART_MIN_PART_SIZE_BYTES:
             await self._upload_simple(tar_filename, open_stream)
         else:
-            await self._upload_multipart(tar_filename, open_stream)
+            await self._upload_multipart(tar_filename, open_stream, on_progress)
 
         # Upload the metadata file
         metadata_content = json.dumps(backup.as_dict())
@@ -185,11 +185,13 @@ class R2BackupAgent(BackupAgent):
         self,
         tar_filename: str,
         open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
-    ):
+        on_progress: OnProgressCallback,
+    ) -> None:
         """Upload a large file using multipart upload.
 
         :param tar_filename: The target filename for the backup.
         :param open_stream: A function returning an async iterator that yields bytes.
+        :param on_progress: A callback to report the number of uploaded bytes.
         """
         _LOGGER.debug("Starting multipart upload for %s", tar_filename)
         key = self._with_prefix(tar_filename)
@@ -203,6 +205,7 @@ class R2BackupAgent(BackupAgent):
         part_number = 1
         buffer = bytearray()  # bytes buffer to store the data
         offset = 0  # start index of unread data inside buffer
+        bytes_uploaded = 0
 
         stream = await open_stream()
         async for chunk in stream:
@@ -231,6 +234,8 @@ class R2BackupAgent(BackupAgent):
                         Body=part_data.tobytes(),
                     )
                     parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+                    bytes_uploaded += len(part_data)
+                    on_progress(bytes_uploaded=bytes_uploaded)
                     part_number += 1
                 finally:
                     view.release()
@@ -259,6 +264,8 @@ class R2BackupAgent(BackupAgent):
                 Body=remaining_data.tobytes(),
             )
             parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+            bytes_uploaded += len(remaining_data)
+            on_progress(bytes_uploaded=bytes_uploaded)
 
         await cast(Any, self._client).complete_multipart_upload(
             Bucket=self._bucket,

homeassistant/components/idrive_e2/backup.py

@@ -142,7 +142,7 @@ class IDriveE2BackupAgent(BackupAgent):
         if backup.size < MULTIPART_MIN_PART_SIZE_BYTES:
             await self._upload_simple(tar_filename, open_stream)
         else:
-            await self._upload_multipart(tar_filename, open_stream)
+            await self._upload_multipart(tar_filename, open_stream, on_progress)
 
         # Upload the metadata file
         metadata_content = json.dumps(backup.as_dict())
@@ -183,11 +183,13 @@ class IDriveE2BackupAgent(BackupAgent):
         self,
         tar_filename: str,
         open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
+        on_progress: OnProgressCallback,
     ) -> None:
         """Upload a large file using multipart upload.
 
         :param tar_filename: The target filename for the backup.
         :param open_stream: A function returning an async iterator that yields bytes.
+        :param on_progress: A callback to report the number of uploaded bytes.
         """
         _LOGGER.debug("Starting multipart upload for %s", tar_filename)
         multipart_upload = await cast(Any, self._client).create_multipart_upload(
@@ -200,6 +202,7 @@ class IDriveE2BackupAgent(BackupAgent):
         part_number = 1
         buffer = bytearray()  # bytes buffer to store the data
         offset = 0  # start index of unread data inside buffer
+        bytes_uploaded = 0
 
         stream = await open_stream()
         async for chunk in stream:
@@ -228,6 +231,8 @@ class IDriveE2BackupAgent(BackupAgent):
                         Body=part_data.tobytes(),
                     )
                     parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+                    bytes_uploaded += len(part_data)
+                    on_progress(bytes_uploaded=bytes_uploaded)
                     part_number += 1
                 finally:
                     view.release()
@@ -256,6 +261,8 @@ class IDriveE2BackupAgent(BackupAgent):
                 Body=remaining_data.tobytes(),
             )
             parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
+            bytes_uploaded += len(remaining_data)
+            on_progress(bytes_uploaded=bytes_uploaded)
 
         await cast(Any, self._client).complete_multipart_upload(
             Bucket=self._bucket,
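
One detail visible in these hunks: several client calls go through cast(Any, self._client). A plausible reading, not stated in the commit, is that botocore-style clients build their operation methods at runtime, so a strict type checker cannot see upload_part or complete_multipart_upload on the client's declared type, and the cast opts those call sites out of checking. An illustrative, self-contained sketch:

import asyncio
from typing import Any, cast


class RuntimeClient:
    """Stand-in for a client whose methods are generated at runtime."""

    def __getattr__(self, name: str):
        async def _op(**kwargs: Any) -> dict[str, Any]:
            return {"operation": name, **kwargs}

        return _op


async def _demo() -> None:
    client = RuntimeClient()
    # Without cast(Any, ...), a strict type checker would reject this
    # attribute access, since RuntimeClient declares no such method.
    resp = await cast(Any, client).complete_multipart_upload(Bucket="b")
    assert resp == {"operation": "complete_multipart_upload", "Bucket": "b"}


asyncio.run(_demo())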

tests/components/aws_s3/test_backup.py

@@ -21,7 +21,12 @@ from homeassistant.components.aws_s3.const import (
     DATA_BACKUP_AGENT_LISTENERS,
     DOMAIN,
 )
-from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
+from homeassistant.components.backup import (
+    DATA_MANAGER,
+    DOMAIN as BACKUP_DOMAIN,
+    AgentBackup,
+    UploadBackupEvent,
+)
 from homeassistant.core import HomeAssistant
 from homeassistant.setup import async_setup_component
@@ -332,6 +337,65 @@ async def test_agents_upload_network_failure(
     assert "Upload failed for aws_s3" in caplog.text
 
 
+@pytest.mark.parametrize(
+    "backup_size", [MULTIPART_MIN_PART_SIZE_BYTES * 2], ids=["large"]
+)
+async def test_agents_upload_on_progress(
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    mock_client: MagicMock,
+    mock_config_entry: MockConfigEntry,
+    mock_agent_backup: AgentBackup,
+) -> None:
+    """Test agent upload backup emits UploadBackupEvent via on_progress."""
+    client = await hass_client()
+    manager = hass.data[DATA_MANAGER]
+    events: list[UploadBackupEvent] = []
+
+    def _collect(event: UploadBackupEvent) -> None:
+        if isinstance(event, UploadBackupEvent):
+            events.append(event)
+
+    unsub = manager.async_subscribe_events(_collect)
+
+    with (
+        patch(
+            "homeassistant.components.backup.manager.BackupManager.async_get_backup",
+            return_value=mock_agent_backup,
+        ),
+        patch(
+            "homeassistant.components.backup.manager.read_backup",
+            return_value=mock_agent_backup,
+        ),
+        patch("pathlib.Path.open") as mocked_open,
+    ):
+        mocked_open.return_value.read = Mock(
+            side_effect=[
+                b"a" * mock_agent_backup.size,
+                b"",
+            ]
+        )
+        resp = await client.post(
+            f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.entry_id}",
+            data={"file": StringIO("test")},
+        )
+
+    unsub()
+
+    assert resp.status == 201
+    agent_id = f"{DOMAIN}.{mock_config_entry.entry_id}"
+    agent_events = [e for e in events if e.agent_id == agent_id]
+    assert len(agent_events) >= 2
+    assert all(e.total_bytes == mock_agent_backup.size for e in agent_events)
+
+    # Verify events report distinct increasing byte counts
+    uploaded_bytes = [e.uploaded_bytes for e in agent_events]
+    assert uploaded_bytes == sorted(uploaded_bytes)
+    assert len(set(uploaded_bytes)) == len(uploaded_bytes)
+
+    # Verify at least one intermediate event (uploaded_bytes < total_bytes)
+    assert agent_events[0].uploaded_bytes < agent_events[0].total_bytes
+
+
 async def test_agents_download(
     hass_client: ClientSessionGenerator,
     mock_client: MagicMock,
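
All three new tests share the same assertion recipe: collect progress events, then check that the byte counts grow strictly and include at least one intermediate value. A stripped-down, framework-free version of that recipe, with invented event and helper names:

from dataclasses import dataclass


@dataclass
class ProgressEvent:  # invented stand-in for UploadBackupEvent
    agent_id: str
    uploaded_bytes: int
    total_bytes: int


def assert_monotonic_progress(
    events: list[ProgressEvent], agent_id: str, total: int
) -> None:
    agent_events = [e for e in events if e.agent_id == agent_id]
    assert len(agent_events) >= 2  # at least one full part plus the trailing part
    assert all(e.total_bytes == total for e in agent_events)
    uploaded = [e.uploaded_bytes for e in agent_events]
    assert uploaded == sorted(uploaded)  # cumulative counts never decrease
    assert len(set(uploaded)) == len(uploaded)  # every event adds new bytes
    assert agent_events[0].uploaded_bytes < total  # first event is intermediate


assert_monotonic_progress(
    [ProgressEvent("aws_s3.entry", 512, 1024), ProgressEvent("aws_s3.entry", 1024, 1024)],
    "aws_s3.entry",
    1024,
)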

tests/components/cloudflare_r2/test_backup.py

@@ -9,7 +9,12 @@ from unittest.mock import AsyncMock, Mock, patch
 from botocore.exceptions import BotoCoreError, ConnectTimeoutError
 import pytest
 
-from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
+from homeassistant.components.backup import (
+    DATA_MANAGER,
+    DOMAIN as BACKUP_DOMAIN,
+    AgentBackup,
+    UploadBackupEvent,
+)
 from homeassistant.components.cloudflare_r2.backup import (
     MULTIPART_MIN_PART_SIZE_BYTES,
     R2BackupAgent,
@@ -400,7 +405,7 @@ async def test_multipart_upload_consistent_part_sizes(
     mock_client.upload_part.side_effect = record_upload_part
 
-    await agent._upload_multipart("test.tar", open_stream)
+    await agent._upload_multipart("test.tar", open_stream, Mock())
 
     # Verify that all non-trailing parts have the same size
     assert len(uploaded_part_sizes) >= 2, "Expected at least 2 parts"
@@ -417,6 +422,68 @@ async def test_multipart_upload_consistent_part_sizes(
     assert uploaded_part_sizes[-1] == expected_trailing
 
 
+@pytest.mark.parametrize(
+    "test_backup",
+    [MULTIPART_MIN_PART_SIZE_BYTES * 2],
+    indirect=True,
+    ids=["large"],
+)
+async def test_agents_upload_on_progress(
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    mock_client: MagicMock,
+    mock_config_entry: MockConfigEntry,
+    test_backup: AgentBackup,
+) -> None:
+    """Test agent upload backup emits UploadBackupEvent via on_progress."""
+    client = await hass_client()
+    manager = hass.data[DATA_MANAGER]
+    events: list[UploadBackupEvent] = []
+
+    def _collect(event: UploadBackupEvent) -> None:
+        if isinstance(event, UploadBackupEvent):
+            events.append(event)
+
+    unsub = manager.async_subscribe_events(_collect)
+
+    with (
+        patch(
+            "homeassistant.components.backup.manager.BackupManager.async_get_backup",
+            return_value=test_backup,
+        ),
+        patch(
+            "homeassistant.components.backup.manager.read_backup",
+            return_value=test_backup,
+        ),
+        patch("pathlib.Path.open") as mocked_open,
+    ):
+        mocked_open.return_value.read = Mock(
+            side_effect=[
+                b"a" * test_backup.size,
+                b"",
+            ]
+        )
+        resp = await client.post(
+            f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.entry_id}",
+            data={"file": StringIO("test")},
+        )
+
+    unsub()
+
+    assert resp.status == 201
+    agent_id = f"{DOMAIN}.{mock_config_entry.entry_id}"
+    agent_events = [e for e in events if e.agent_id == agent_id]
+    assert len(agent_events) >= 2
+    assert all(e.total_bytes == test_backup.size for e in agent_events)
+
+    # Verify events report distinct increasing byte counts
+    uploaded_bytes = [e.uploaded_bytes for e in agent_events]
+    assert uploaded_bytes == sorted(uploaded_bytes)
+    assert len(set(uploaded_bytes)) == len(uploaded_bytes)
+
+    # Verify at least one intermediate event (uploaded_bytes < total_bytes)
+    assert agent_events[0].uploaded_bytes < agent_events[0].total_bytes
+
+
 async def test_agents_download(
     hass_client: ClientSessionGenerator,
     mock_client: MagicMock,
@@ -538,7 +605,7 @@ async def test_multipart_upload_uses_prefix_for_all_calls(
     async def open_stream():
         return stream()
 
-    await agent._upload_multipart("test.tar", open_stream)
+    await agent._upload_multipart("test.tar", open_stream, Mock())
 
     prefixed_key = "ha/backups/test.tar"

tests/components/idrive_e2/test_backup.py

@@ -9,7 +9,12 @@ from unittest.mock import AsyncMock, Mock, patch
 from botocore.exceptions import ConnectTimeoutError
 import pytest
 
-from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
+from homeassistant.components.backup import (
+    DATA_MANAGER,
+    DOMAIN as BACKUP_DOMAIN,
+    AgentBackup,
+    UploadBackupEvent,
+)
 from homeassistant.components.idrive_e2.backup import (
     MULTIPART_MIN_PART_SIZE_BYTES,
     BotoCoreError,
@@ -412,7 +417,7 @@ async def test_multipart_upload_consistent_part_sizes(
     mock_client.upload_part.side_effect = record_upload_part
 
-    await agent._upload_multipart("test.tar", open_stream)
+    await agent._upload_multipart("test.tar", open_stream, Mock())
 
     # Verify that all non-trailing parts have the same size
     assert len(uploaded_part_sizes) >= 2, "Expected at least 2 parts"
@@ -429,6 +434,68 @@ async def test_multipart_upload_consistent_part_sizes(
     assert uploaded_part_sizes[-1] == expected_trailing
 
 
+@pytest.mark.parametrize(
+    "agent_backup",
+    [MULTIPART_MIN_PART_SIZE_BYTES * 2],
+    indirect=True,
+    ids=["large"],
+)
+async def test_agents_upload_on_progress(
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    mock_client: MagicMock,
+    mock_config_entry: MockConfigEntry,
+    agent_backup: AgentBackup,
+) -> None:
+    """Test agent upload backup emits UploadBackupEvent via on_progress."""
+    client = await hass_client()
+    manager = hass.data[DATA_MANAGER]
+    events: list[UploadBackupEvent] = []
+
+    def _collect(event: UploadBackupEvent) -> None:
+        if isinstance(event, UploadBackupEvent):
+            events.append(event)
+
+    unsub = manager.async_subscribe_events(_collect)
+
+    with (
+        patch(
+            "homeassistant.components.backup.manager.BackupManager.async_get_backup",
+            return_value=agent_backup,
+        ),
+        patch(
+            "homeassistant.components.backup.manager.read_backup",
+            return_value=agent_backup,
+        ),
+        patch("pathlib.Path.open") as mocked_open,
+    ):
+        mocked_open.return_value.read = Mock(
+            side_effect=[
+                b"a" * agent_backup.size,
+                b"",
+            ]
+        )
+        resp = await client.post(
+            f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.entry_id}",
+            data={"file": StringIO("test")},
+        )
+
+    unsub()
+
+    assert resp.status == 201
+    agent_id = f"{DOMAIN}.{mock_config_entry.entry_id}"
+    agent_events = [e for e in events if e.agent_id == agent_id]
+    assert len(agent_events) >= 2
+    assert all(e.total_bytes == agent_backup.size for e in agent_events)
+
+    # Verify events report distinct increasing byte counts
+    uploaded_bytes = [e.uploaded_bytes for e in agent_events]
+    assert uploaded_bytes == sorted(uploaded_bytes)
+    assert len(set(uploaded_bytes)) == len(uploaded_bytes)
+
+    # Verify at least one intermediate event (uploaded_bytes < total_bytes)
+    assert agent_events[0].uploaded_bytes < agent_events[0].total_bytes
+
+
 async def test_agents_download(
     hass_client: ClientSessionGenerator,
     mock_client: MagicMock,
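
The R2 and IDrive e2 variants size their backup fixtures ("test_backup" and "agent_backup") through pytest's indirect parametrization: the parametrize value is routed into the fixture via request.param instead of being passed to the test function directly. A minimal standalone illustration of the idiom, with fixture and test names invented here:

import pytest

MULTIPART_MIN_PART_SIZE_BYTES = 5 * 1024 * 1024  # assumed value for the sketch


@pytest.fixture
def backup_size(request: pytest.FixtureRequest) -> int:
    # Default to a small backup unless a test parametrizes this fixture.
    return getattr(request, "param", 1024)


@pytest.mark.parametrize(
    "backup_size",
    [MULTIPART_MIN_PART_SIZE_BYTES * 2],
    indirect=True,
    ids=["large"],
)
def test_large_backup_uses_multipart(backup_size: int) -> None:
    assert backup_size >= MULTIPART_MIN_PART_SIZE_BYTES

This keeps one test body while letting each integration build an appropriately sized AgentBackup in its own fixture.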