mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 08:26:41 +01:00
Add upload progress tracking to S3 integrations (#166325)
This commit is contained in:
@@ -147,7 +147,7 @@ class S3BackupAgent(BackupAgent):
|
||||
if backup.size < MULTIPART_MIN_PART_SIZE_BYTES:
|
||||
await self._upload_simple(tar_filename, open_stream)
|
||||
else:
|
||||
await self._upload_multipart(tar_filename, open_stream)
|
||||
await self._upload_multipart(tar_filename, open_stream, on_progress)
|
||||
|
||||
# Upload the metadata file
|
||||
metadata_content = json.dumps(backup.as_dict())
|
||||
@@ -188,11 +188,13 @@ class S3BackupAgent(BackupAgent):
|
||||
self,
|
||||
tar_filename: str,
|
||||
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
|
||||
):
|
||||
on_progress: OnProgressCallback,
|
||||
) -> None:
|
||||
"""Upload a large file using multipart upload.
|
||||
|
||||
:param tar_filename: The target filename for the backup.
|
||||
:param open_stream: A function returning an async iterator that yields bytes.
|
||||
:param on_progress: A callback to report the number of uploaded bytes.
|
||||
"""
|
||||
_LOGGER.debug("Starting multipart upload for %s", tar_filename)
|
||||
multipart_upload = await self._client.create_multipart_upload(
|
||||
@@ -205,6 +207,7 @@ class S3BackupAgent(BackupAgent):
|
||||
part_number = 1
|
||||
buffer = bytearray() # bytes buffer to store the data
|
||||
offset = 0 # start index of unread data inside buffer
|
||||
bytes_uploaded = 0
|
||||
|
||||
stream = await open_stream()
|
||||
async for chunk in stream:
|
||||
@@ -233,6 +236,8 @@ class S3BackupAgent(BackupAgent):
|
||||
Body=part_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
bytes_uploaded += len(part_data)
|
||||
on_progress(bytes_uploaded=bytes_uploaded)
|
||||
part_number += 1
|
||||
finally:
|
||||
view.release()
|
||||
@@ -261,6 +266,8 @@ class S3BackupAgent(BackupAgent):
|
||||
Body=remaining_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
bytes_uploaded += len(remaining_data)
|
||||
on_progress(bytes_uploaded=bytes_uploaded)
|
||||
|
||||
await cast(Any, self._client).complete_multipart_upload(
|
||||
Bucket=self._bucket,
|
||||
|
||||
@@ -144,7 +144,7 @@ class R2BackupAgent(BackupAgent):
|
||||
if backup.size < MULTIPART_MIN_PART_SIZE_BYTES:
|
||||
await self._upload_simple(tar_filename, open_stream)
|
||||
else:
|
||||
await self._upload_multipart(tar_filename, open_stream)
|
||||
await self._upload_multipart(tar_filename, open_stream, on_progress)
|
||||
|
||||
# Upload the metadata file
|
||||
metadata_content = json.dumps(backup.as_dict())
|
||||
@@ -185,11 +185,13 @@ class R2BackupAgent(BackupAgent):
|
||||
self,
|
||||
tar_filename: str,
|
||||
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
|
||||
):
|
||||
on_progress: OnProgressCallback,
|
||||
) -> None:
|
||||
"""Upload a large file using multipart upload.
|
||||
|
||||
:param tar_filename: The target filename for the backup.
|
||||
:param open_stream: A function returning an async iterator that yields bytes.
|
||||
:param on_progress: A callback to report the number of uploaded bytes.
|
||||
"""
|
||||
_LOGGER.debug("Starting multipart upload for %s", tar_filename)
|
||||
key = self._with_prefix(tar_filename)
|
||||
@@ -203,6 +205,7 @@ class R2BackupAgent(BackupAgent):
|
||||
part_number = 1
|
||||
buffer = bytearray() # bytes buffer to store the data
|
||||
offset = 0 # start index of unread data inside buffer
|
||||
bytes_uploaded = 0
|
||||
|
||||
stream = await open_stream()
|
||||
async for chunk in stream:
|
||||
@@ -231,6 +234,8 @@ class R2BackupAgent(BackupAgent):
|
||||
Body=part_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
bytes_uploaded += len(part_data)
|
||||
on_progress(bytes_uploaded=bytes_uploaded)
|
||||
part_number += 1
|
||||
finally:
|
||||
view.release()
|
||||
@@ -259,6 +264,8 @@ class R2BackupAgent(BackupAgent):
|
||||
Body=remaining_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
bytes_uploaded += len(remaining_data)
|
||||
on_progress(bytes_uploaded=bytes_uploaded)
|
||||
|
||||
await cast(Any, self._client).complete_multipart_upload(
|
||||
Bucket=self._bucket,
|
||||
|
||||
@@ -142,7 +142,7 @@ class IDriveE2BackupAgent(BackupAgent):
|
||||
if backup.size < MULTIPART_MIN_PART_SIZE_BYTES:
|
||||
await self._upload_simple(tar_filename, open_stream)
|
||||
else:
|
||||
await self._upload_multipart(tar_filename, open_stream)
|
||||
await self._upload_multipart(tar_filename, open_stream, on_progress)
|
||||
|
||||
# Upload the metadata file
|
||||
metadata_content = json.dumps(backup.as_dict())
|
||||
@@ -183,11 +183,13 @@ class IDriveE2BackupAgent(BackupAgent):
|
||||
self,
|
||||
tar_filename: str,
|
||||
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
|
||||
on_progress: OnProgressCallback,
|
||||
) -> None:
|
||||
"""Upload a large file using multipart upload.
|
||||
|
||||
:param tar_filename: The target filename for the backup.
|
||||
:param open_stream: A function returning an async iterator that yields bytes.
|
||||
:param on_progress: A callback to report the number of uploaded bytes.
|
||||
"""
|
||||
_LOGGER.debug("Starting multipart upload for %s", tar_filename)
|
||||
multipart_upload = await cast(Any, self._client).create_multipart_upload(
|
||||
@@ -200,6 +202,7 @@ class IDriveE2BackupAgent(BackupAgent):
|
||||
part_number = 1
|
||||
buffer = bytearray() # bytes buffer to store the data
|
||||
offset = 0 # start index of unread data inside buffer
|
||||
bytes_uploaded = 0
|
||||
|
||||
stream = await open_stream()
|
||||
async for chunk in stream:
|
||||
@@ -228,6 +231,8 @@ class IDriveE2BackupAgent(BackupAgent):
|
||||
Body=part_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
bytes_uploaded += len(part_data)
|
||||
on_progress(bytes_uploaded=bytes_uploaded)
|
||||
part_number += 1
|
||||
finally:
|
||||
view.release()
|
||||
@@ -256,6 +261,8 @@ class IDriveE2BackupAgent(BackupAgent):
|
||||
Body=remaining_data.tobytes(),
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": part["ETag"]})
|
||||
bytes_uploaded += len(remaining_data)
|
||||
on_progress(bytes_uploaded=bytes_uploaded)
|
||||
|
||||
await cast(Any, self._client).complete_multipart_upload(
|
||||
Bucket=self._bucket,
|
||||
|
||||
@@ -21,7 +21,12 @@ from homeassistant.components.aws_s3.const import (
|
||||
DATA_BACKUP_AGENT_LISTENERS,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
|
||||
from homeassistant.components.backup import (
|
||||
DATA_MANAGER,
|
||||
DOMAIN as BACKUP_DOMAIN,
|
||||
AgentBackup,
|
||||
UploadBackupEvent,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
@@ -332,6 +337,65 @@ async def test_agents_upload_network_failure(
|
||||
assert "Upload failed for aws_s3" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backup_size", [MULTIPART_MIN_PART_SIZE_BYTES * 2], ids=["large"]
|
||||
)
|
||||
async def test_agents_upload_on_progress(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_client: MagicMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_agent_backup: AgentBackup,
|
||||
) -> None:
|
||||
"""Test agent upload backup emits UploadBackupEvent via on_progress."""
|
||||
client = await hass_client()
|
||||
|
||||
manager = hass.data[DATA_MANAGER]
|
||||
events: list[UploadBackupEvent] = []
|
||||
|
||||
def _collect(event: UploadBackupEvent) -> None:
|
||||
if isinstance(event, UploadBackupEvent):
|
||||
events.append(event)
|
||||
|
||||
unsub = manager.async_subscribe_events(_collect)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
|
||||
return_value=mock_agent_backup,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.backup.manager.read_backup",
|
||||
return_value=mock_agent_backup,
|
||||
),
|
||||
patch("pathlib.Path.open") as mocked_open,
|
||||
):
|
||||
mocked_open.return_value.read = Mock(
|
||||
side_effect=[
|
||||
b"a" * mock_agent_backup.size,
|
||||
b"",
|
||||
]
|
||||
)
|
||||
resp = await client.post(
|
||||
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.entry_id}",
|
||||
data={"file": StringIO("test")},
|
||||
)
|
||||
|
||||
unsub()
|
||||
|
||||
assert resp.status == 201
|
||||
agent_id = f"{DOMAIN}.{mock_config_entry.entry_id}"
|
||||
agent_events = [e for e in events if e.agent_id == agent_id]
|
||||
assert len(agent_events) >= 2
|
||||
assert all(e.total_bytes == mock_agent_backup.size for e in agent_events)
|
||||
# Verify events report distinct increasing byte counts
|
||||
uploaded_bytes = [e.uploaded_bytes for e in agent_events]
|
||||
assert uploaded_bytes == sorted(uploaded_bytes)
|
||||
assert len(set(uploaded_bytes)) == len(uploaded_bytes)
|
||||
# Verify at least one intermediate event (uploaded_bytes < total_bytes)
|
||||
assert agent_events[0].uploaded_bytes < agent_events[0].total_bytes
|
||||
|
||||
|
||||
async def test_agents_download(
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_client: MagicMock,
|
||||
|
||||
@@ -9,7 +9,12 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||
from botocore.exceptions import BotoCoreError, ConnectTimeoutError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
|
||||
from homeassistant.components.backup import (
|
||||
DATA_MANAGER,
|
||||
DOMAIN as BACKUP_DOMAIN,
|
||||
AgentBackup,
|
||||
UploadBackupEvent,
|
||||
)
|
||||
from homeassistant.components.cloudflare_r2.backup import (
|
||||
MULTIPART_MIN_PART_SIZE_BYTES,
|
||||
R2BackupAgent,
|
||||
@@ -400,7 +405,7 @@ async def test_multipart_upload_consistent_part_sizes(
|
||||
|
||||
mock_client.upload_part.side_effect = record_upload_part
|
||||
|
||||
await agent._upload_multipart("test.tar", open_stream)
|
||||
await agent._upload_multipart("test.tar", open_stream, Mock())
|
||||
|
||||
# Verify that all non-trailing parts have the same size
|
||||
assert len(uploaded_part_sizes) >= 2, "Expected at least 2 parts"
|
||||
@@ -417,6 +422,68 @@ async def test_multipart_upload_consistent_part_sizes(
|
||||
assert uploaded_part_sizes[-1] == expected_trailing
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_backup",
|
||||
[MULTIPART_MIN_PART_SIZE_BYTES * 2],
|
||||
indirect=True,
|
||||
ids=["large"],
|
||||
)
|
||||
async def test_agents_upload_on_progress(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_client: MagicMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
test_backup: AgentBackup,
|
||||
) -> None:
|
||||
"""Test agent upload backup emits UploadBackupEvent via on_progress."""
|
||||
client = await hass_client()
|
||||
|
||||
manager = hass.data[DATA_MANAGER]
|
||||
events: list[UploadBackupEvent] = []
|
||||
|
||||
def _collect(event: UploadBackupEvent) -> None:
|
||||
if isinstance(event, UploadBackupEvent):
|
||||
events.append(event)
|
||||
|
||||
unsub = manager.async_subscribe_events(_collect)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
|
||||
return_value=test_backup,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.backup.manager.read_backup",
|
||||
return_value=test_backup,
|
||||
),
|
||||
patch("pathlib.Path.open") as mocked_open,
|
||||
):
|
||||
mocked_open.return_value.read = Mock(
|
||||
side_effect=[
|
||||
b"a" * test_backup.size,
|
||||
b"",
|
||||
]
|
||||
)
|
||||
resp = await client.post(
|
||||
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.entry_id}",
|
||||
data={"file": StringIO("test")},
|
||||
)
|
||||
|
||||
unsub()
|
||||
|
||||
assert resp.status == 201
|
||||
agent_id = f"{DOMAIN}.{mock_config_entry.entry_id}"
|
||||
agent_events = [e for e in events if e.agent_id == agent_id]
|
||||
assert len(agent_events) >= 2
|
||||
assert all(e.total_bytes == test_backup.size for e in agent_events)
|
||||
# Verify events report distinct increasing byte counts
|
||||
uploaded_bytes = [e.uploaded_bytes for e in agent_events]
|
||||
assert uploaded_bytes == sorted(uploaded_bytes)
|
||||
assert len(set(uploaded_bytes)) == len(uploaded_bytes)
|
||||
# Verify at least one intermediate event (uploaded_bytes < total_bytes)
|
||||
assert agent_events[0].uploaded_bytes < agent_events[0].total_bytes
|
||||
|
||||
|
||||
async def test_agents_download(
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_client: MagicMock,
|
||||
@@ -538,7 +605,7 @@ async def test_multipart_upload_uses_prefix_for_all_calls(
|
||||
async def open_stream():
|
||||
return stream()
|
||||
|
||||
await agent._upload_multipart("test.tar", open_stream)
|
||||
await agent._upload_multipart("test.tar", open_stream, Mock())
|
||||
|
||||
prefixed_key = "ha/backups/test.tar"
|
||||
|
||||
|
||||
@@ -9,7 +9,12 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||
from botocore.exceptions import ConnectTimeoutError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.backup import DOMAIN as BACKUP_DOMAIN, AgentBackup
|
||||
from homeassistant.components.backup import (
|
||||
DATA_MANAGER,
|
||||
DOMAIN as BACKUP_DOMAIN,
|
||||
AgentBackup,
|
||||
UploadBackupEvent,
|
||||
)
|
||||
from homeassistant.components.idrive_e2.backup import (
|
||||
MULTIPART_MIN_PART_SIZE_BYTES,
|
||||
BotoCoreError,
|
||||
@@ -412,7 +417,7 @@ async def test_multipart_upload_consistent_part_sizes(
|
||||
|
||||
mock_client.upload_part.side_effect = record_upload_part
|
||||
|
||||
await agent._upload_multipart("test.tar", open_stream)
|
||||
await agent._upload_multipart("test.tar", open_stream, Mock())
|
||||
|
||||
# Verify that all non-trailing parts have the same size
|
||||
assert len(uploaded_part_sizes) >= 2, "Expected at least 2 parts"
|
||||
@@ -429,6 +434,68 @@ async def test_multipart_upload_consistent_part_sizes(
|
||||
assert uploaded_part_sizes[-1] == expected_trailing
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_backup",
|
||||
[MULTIPART_MIN_PART_SIZE_BYTES * 2],
|
||||
indirect=True,
|
||||
ids=["large"],
|
||||
)
|
||||
async def test_agents_upload_on_progress(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_client: MagicMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
agent_backup: AgentBackup,
|
||||
) -> None:
|
||||
"""Test agent upload backup emits UploadBackupEvent via on_progress."""
|
||||
client = await hass_client()
|
||||
|
||||
manager = hass.data[DATA_MANAGER]
|
||||
events: list[UploadBackupEvent] = []
|
||||
|
||||
def _collect(event: UploadBackupEvent) -> None:
|
||||
if isinstance(event, UploadBackupEvent):
|
||||
events.append(event)
|
||||
|
||||
unsub = manager.async_subscribe_events(_collect)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.backup.manager.BackupManager.async_get_backup",
|
||||
return_value=agent_backup,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.backup.manager.read_backup",
|
||||
return_value=agent_backup,
|
||||
),
|
||||
patch("pathlib.Path.open") as mocked_open,
|
||||
):
|
||||
mocked_open.return_value.read = Mock(
|
||||
side_effect=[
|
||||
b"a" * agent_backup.size,
|
||||
b"",
|
||||
]
|
||||
)
|
||||
resp = await client.post(
|
||||
f"/api/backup/upload?agent_id={DOMAIN}.{mock_config_entry.entry_id}",
|
||||
data={"file": StringIO("test")},
|
||||
)
|
||||
|
||||
unsub()
|
||||
|
||||
assert resp.status == 201
|
||||
agent_id = f"{DOMAIN}.{mock_config_entry.entry_id}"
|
||||
agent_events = [e for e in events if e.agent_id == agent_id]
|
||||
assert len(agent_events) >= 2
|
||||
assert all(e.total_bytes == agent_backup.size for e in agent_events)
|
||||
# Verify events report distinct increasing byte counts
|
||||
uploaded_bytes = [e.uploaded_bytes for e in agent_events]
|
||||
assert uploaded_bytes == sorted(uploaded_bytes)
|
||||
assert len(set(uploaded_bytes)) == len(uploaded_bytes)
|
||||
# Verify at least one intermediate event (uploaded_bytes < total_bytes)
|
||||
assert agent_events[0].uploaded_bytes < agent_events[0].total_bytes
|
||||
|
||||
|
||||
async def test_agents_download(
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_client: MagicMock,
|
||||
|
||||
Reference in New Issue
Block a user