diff --git a/homeassistant/components/google_drive/backup.py b/homeassistant/components/google_drive/backup.py index e6967d95eaf..40ebc7c7cec 100644 --- a/homeassistant/components/google_drive/backup.py +++ b/homeassistant/components/google_drive/backup.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Callable, Coroutine +from functools import wraps import logging from typing import Any @@ -84,8 +85,22 @@ class GoogleDriveBackupAgent(BackupAgent): :param open_stream: A function returning an async iterator that yields bytes. :param backup: Metadata about the backup that should be uploaded. """ + + @wraps(open_stream) + async def wrapped_open_stream() -> AsyncIterator[bytes]: + stream = await open_stream() + + async def _progress_stream() -> AsyncIterator[bytes]: + bytes_uploaded = 0 + async for chunk in stream: + yield chunk + bytes_uploaded += len(chunk) + on_progress(bytes_uploaded=bytes_uploaded) + + return _progress_stream() + try: - await self._client.async_upload_backup(open_stream, backup) + await self._client.async_upload_backup(wrapped_open_stream, backup) except (GoogleDriveApiError, HomeAssistantError, TimeoutError) as err: raise BackupAgentError(f"Failed to upload backup: {err}") from err diff --git a/tests/components/google_drive/test_backup.py b/tests/components/google_drive/test_backup.py index b731be0c34e..48e2b72878a 100644 --- a/tests/components/google_drive/test_backup.py +++ b/tests/components/google_drive/test_backup.py @@ -1,5 +1,6 @@ """Test the Google Drive backup platform.""" +from collections.abc import AsyncIterator from io import StringIO import json from typing import Any @@ -16,6 +17,7 @@ from homeassistant.components.backup import ( AgentBackup, ) from homeassistant.components.google_drive import DOMAIN +from homeassistant.components.google_drive.backup import GoogleDriveBackupAgent from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -59,6 +61,18 @@ TEST_AGENT_BACKUP_RESULT = { } +async def consume_stream( + file_metadata: Any, + open_stream: Any, + *args: Any, + **kwargs: Any, +) -> None: + """Consume the stream from the open_stream callable.""" + stream = await open_stream() + async for _ in stream: + pass + + @pytest.fixture(autouse=True) async def setup_integration( hass: HomeAssistant, @@ -283,7 +297,7 @@ async def test_agents_upload( snapshot: SnapshotAssertion, ) -> None: """Test agent upload backup.""" - mock_api.resumable_upload_file = AsyncMock(return_value=None) + mock_api.resumable_upload_file = AsyncMock(side_effect=consume_stream) client = await hass_client() @@ -324,7 +338,7 @@ async def test_agents_upload_create_folder_if_missing( mock_api.create_file = AsyncMock( return_value={"id": "new folder id", "name": "Home Assistant"} ) - mock_api.resumable_upload_file = AsyncMock(return_value=None) + mock_api.resumable_upload_file = AsyncMock(side_effect=consume_stream) client = await hass_client() @@ -354,6 +368,37 @@ async def test_agents_upload_create_folder_if_missing( assert [tuple(mock_call) for mock_call in mock_api.mock_calls] == snapshot +async def test_agents_upload_progress( + hass: HomeAssistant, + mock_api: MagicMock, +) -> None: + """Test agent upload reports progress.""" + mock_api.resumable_upload_file = AsyncMock(side_effect=consume_stream) + + entries = hass.config_entries.async_entries(DOMAIN) + agent = GoogleDriveBackupAgent(entries[0]) + + progress_calls = [] + + def on_progress(*, bytes_uploaded: int, **kwargs: Any) -> None: + progress_calls.append(bytes_uploaded) + + async def open_stream() -> AsyncIterator[bytes]: + async def stream() -> AsyncIterator[bytes]: + yield b"chunk1" + yield b"chunk2" + + return stream() + + await agent.async_upload_backup( + open_stream=open_stream, + backup=TEST_AGENT_BACKUP, + on_progress=on_progress, + ) + + assert progress_calls == [6, 12] + + async def test_agents_upload_fail( hass: HomeAssistant, hass_client: ClientSessionGenerator,