mirror of
https://github.com/home-assistant/supervisor.git
synced 2026-02-15 07:27:13 +00:00
Refactor Docker pull progress with registry manifest fetcher (#6379)
* Use count-based progress for Docker image pulls Refactor Docker image pull progress to use a simpler count-based approach where each layer contributes equally (100% / total_layers) regardless of size. This replaces the previous size-weighted calculation that was susceptible to progress regression. The core issue was that Docker rate-limits concurrent downloads (~3 at a time) and reports layer sizes only when downloading starts. With size- weighted progress, large layers appearing late would cause progress to drop dramatically (e.g., 59% -> 29%) as the total size increased. The new approach: - Each layer contributes equally to overall progress - Per-layer progress: 70% download weight, 30% extraction weight - Progress only starts after first "Downloading" event (when layer count is known) - Always caps at 99% - job completion handles final 100% This simplifies the code by moving progress tracking to a dedicated module (pull_progress.py) and removing complex size-based scaling logic that tried to account for unknown layer sizes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Exclude already-existing layers from pull progress calculation Layers that already exist locally should not count towards download progress since there's nothing to download for them. Only layers that need pulling are included in the progress calculation. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Add registry manifest fetcher for size-based pull progress Fetch image manifests directly from container registries before pulling to get accurate layer sizes upfront. This enables size-weighted progress tracking where each layer contributes proportionally to its byte size, rather than equal weight per layer. Key changes: - Add RegistryManifestFetcher that handles auth discovery via WWW-Authenticate headers, token fetching with optional credentials, and multi-arch manifest list resolution - Update ImagePullProgress to accept manifest layer sizes via set_manifest() and calculate size-weighted progress - Fall back to count-based progress when manifest fetch fails - Pre-populate layer sizes from manifest when creating layer trackers The manifest fetcher supports ghcr.io, Docker Hub, and private registries by using credentials from Docker config when available. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Clamp progress to 100 to prevent floating point precision issues Floating point arithmetic in weighted progress calculations can produce values slightly above 100 (e.g., 100.00000000000001). This causes validation errors when the progress value is checked. Add min(100, ...) clamping to both size-weighted and count-based progress calculations to ensure the result never exceeds 100. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Use sys_websession for manifest fetcher instead of creating new session Reuse the existing CoreSys websession for registry manifest requests instead of creating a new aiohttp session. This improves performance and follows the established pattern used throughout the codebase. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Make platform parameter required and warn on missing platform - Make platform a required parameter in get_manifest() and _fetch_manifest() since it's always provided by the calling code - Return None and log warning when requested platform is not found in multi-arch manifest list, instead of falling back to first manifest which could be the wrong architecture 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Log manifest fetch failures at warning level Users will notice degraded progress tracking when manifest fetch fails, so log at warning level to help diagnose issues. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Add pylint disable comments for protected access in manifest tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Separate download_current and total_size updates in pull progress Update download_current and total_size independently in the DOWNLOADING handler. This ensures download_current is updated even when total is not yet available. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Reject invalid platform format in manifest selection --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -2,13 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, StrEnum
|
||||
from functools import total_ordering
|
||||
from enum import StrEnum
|
||||
from pathlib import PurePath
|
||||
import re
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from ..const import MACHINE_ID
|
||||
|
||||
@@ -81,57 +79,6 @@ class PropagationMode(StrEnum):
|
||||
RSLAVE = "rslave"
|
||||
|
||||
|
||||
@total_ordering
|
||||
class PullImageLayerStage(Enum):
|
||||
"""Job stages for pulling an image layer.
|
||||
|
||||
These are a subset of the statuses in a docker pull image log. They
|
||||
are the standardized ones that are the most useful to us.
|
||||
"""
|
||||
|
||||
PULLING_FS_LAYER = 1, "Pulling fs layer"
|
||||
RETRYING_DOWNLOAD = 2, "Retrying download"
|
||||
DOWNLOADING = 2, "Downloading"
|
||||
VERIFYING_CHECKSUM = 3, "Verifying Checksum"
|
||||
DOWNLOAD_COMPLETE = 4, "Download complete"
|
||||
EXTRACTING = 5, "Extracting"
|
||||
PULL_COMPLETE = 6, "Pull complete"
|
||||
|
||||
def __init__(self, order: int, status: str) -> None:
|
||||
"""Set fields from values."""
|
||||
self.order = order
|
||||
self.status = status
|
||||
|
||||
def __eq__(self, value: object, /) -> bool:
|
||||
"""Check equality, allow StrEnum style comparisons on status."""
|
||||
with suppress(AttributeError):
|
||||
return self.status == cast(PullImageLayerStage, value).status
|
||||
return self.status == value
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
"""Order instances."""
|
||||
with suppress(AttributeError):
|
||||
return self.order < cast(PullImageLayerStage, other).order
|
||||
return False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash instance."""
|
||||
return hash(self.status)
|
||||
|
||||
@classmethod
|
||||
def from_status(cls, status: str) -> PullImageLayerStage | None:
|
||||
"""Return stage instance from pull log status."""
|
||||
for i in cls:
|
||||
if i.status == status:
|
||||
return i
|
||||
|
||||
# This one includes number of seconds until download so its not constant
|
||||
if RE_RETRYING_DOWNLOAD_STATUS.match(status):
|
||||
return cls.RETRYING_DOWNLOAD
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class MountBindOptions:
|
||||
"""Bind options for docker mount."""
|
||||
|
||||
@@ -9,15 +9,15 @@ from contextlib import suppress
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
from time import time
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import aiodocker
|
||||
import aiohttp
|
||||
from awesomeversion import AwesomeVersion
|
||||
from awesomeversion.strategy import AwesomeVersionStrategy
|
||||
import requests
|
||||
|
||||
from ..bus import EventListener
|
||||
from ..const import (
|
||||
ATTR_PASSWORD,
|
||||
ATTR_REGISTRY,
|
||||
@@ -33,25 +33,18 @@ from ..exceptions import (
|
||||
DockerError,
|
||||
DockerHubRateLimitExceeded,
|
||||
DockerJobError,
|
||||
DockerLogOutOfOrder,
|
||||
DockerNotFound,
|
||||
DockerRequestError,
|
||||
)
|
||||
from ..jobs import SupervisorJob
|
||||
from ..jobs.const import JOB_GROUP_DOCKER_INTERFACE, JobConcurrency
|
||||
from ..jobs.decorator import Job
|
||||
from ..jobs.job_group import JobGroup
|
||||
from ..resolution.const import ContextType, IssueType, SuggestionType
|
||||
from ..utils.sentry import async_capture_exception
|
||||
from .const import (
|
||||
DOCKER_HUB,
|
||||
DOCKER_HUB_LEGACY,
|
||||
ContainerState,
|
||||
PullImageLayerStage,
|
||||
RestartPolicy,
|
||||
)
|
||||
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, ContainerState, RestartPolicy
|
||||
from .manager import CommandReturn, ExecReturn, PullLogEntry
|
||||
from .monitor import DockerContainerStateEvent
|
||||
from .pull_progress import ImagePullProgress
|
||||
from .stats import DockerStats
|
||||
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
@@ -203,159 +196,6 @@ class DockerInterface(JobGroup, ABC):
|
||||
|
||||
return credentials
|
||||
|
||||
def _process_pull_image_log( # noqa: C901
|
||||
self, install_job_id: str, reference: PullLogEntry
|
||||
) -> None:
|
||||
"""Process events fired from a docker while pulling an image, filtered to a given job id."""
|
||||
if (
|
||||
reference.job_id != install_job_id
|
||||
or not reference.id
|
||||
or not reference.status
|
||||
or not (stage := PullImageLayerStage.from_status(reference.status))
|
||||
):
|
||||
return
|
||||
|
||||
# Pulling FS Layer is our marker for a layer that needs to be downloaded and extracted. Otherwise it already exists and we can ignore
|
||||
job: SupervisorJob | None = None
|
||||
if stage == PullImageLayerStage.PULLING_FS_LAYER:
|
||||
job = self.sys_jobs.new_job(
|
||||
name="Pulling container image layer",
|
||||
initial_stage=stage.status,
|
||||
reference=reference.id,
|
||||
parent_id=install_job_id,
|
||||
internal=True,
|
||||
)
|
||||
job.done = False
|
||||
return
|
||||
|
||||
# Find our sub job to update details of
|
||||
for j in self.sys_jobs.jobs:
|
||||
if j.parent_id == install_job_id and j.reference == reference.id:
|
||||
job = j
|
||||
break
|
||||
|
||||
# There should no longer be any real risk of logs out of order anymore.
|
||||
# However tests with very small images have shown that sometimes Docker
|
||||
# skips stages in log. So keeping this one as a safety check on null job
|
||||
if not job:
|
||||
raise DockerLogOutOfOrder(
|
||||
f"Received pull image log with status {reference.status} for image id {reference.id} and parent job {install_job_id} but could not find a matching job, skipping",
|
||||
_LOGGER.debug,
|
||||
)
|
||||
|
||||
# For progress calculation we assume downloading is 70% of time, extracting is 30% and others stages negligible
|
||||
progress = job.progress
|
||||
match stage:
|
||||
case PullImageLayerStage.DOWNLOADING | PullImageLayerStage.EXTRACTING:
|
||||
if (
|
||||
reference.progress_detail
|
||||
and reference.progress_detail.current
|
||||
and reference.progress_detail.total
|
||||
):
|
||||
progress = (
|
||||
reference.progress_detail.current
|
||||
/ reference.progress_detail.total
|
||||
)
|
||||
if stage == PullImageLayerStage.DOWNLOADING:
|
||||
progress = 70 * progress
|
||||
else:
|
||||
progress = 70 + 30 * progress
|
||||
case (
|
||||
PullImageLayerStage.VERIFYING_CHECKSUM
|
||||
| PullImageLayerStage.DOWNLOAD_COMPLETE
|
||||
):
|
||||
progress = 70
|
||||
case PullImageLayerStage.PULL_COMPLETE:
|
||||
progress = 100
|
||||
case PullImageLayerStage.RETRYING_DOWNLOAD:
|
||||
progress = 0
|
||||
|
||||
# No real risk of getting things out of order in current implementation
|
||||
# but keeping this one in case another change to these trips us up.
|
||||
if stage != PullImageLayerStage.RETRYING_DOWNLOAD and progress < job.progress:
|
||||
raise DockerLogOutOfOrder(
|
||||
f"Received pull image log with status {reference.status} for job {job.uuid} that implied progress was {progress} but current progress is {job.progress}, skipping",
|
||||
_LOGGER.debug,
|
||||
)
|
||||
|
||||
# Our filters have all passed. Time to update the job
|
||||
# Only downloading and extracting have progress details. Use that to set extra
|
||||
# We'll leave it around on later stages as the total bytes may be useful after that stage
|
||||
# Enforce range to prevent float drift error
|
||||
progress = max(0, min(progress, 100))
|
||||
if (
|
||||
stage in {PullImageLayerStage.DOWNLOADING, PullImageLayerStage.EXTRACTING}
|
||||
and reference.progress_detail
|
||||
and reference.progress_detail.current is not None
|
||||
and reference.progress_detail.total is not None
|
||||
):
|
||||
job.update(
|
||||
progress=progress,
|
||||
stage=stage.status,
|
||||
extra={
|
||||
"current": reference.progress_detail.current,
|
||||
"total": reference.progress_detail.total,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# If we reach DOWNLOAD_COMPLETE without ever having set extra (small layers that skip
|
||||
# the downloading phase), set a minimal extra so aggregate progress calculation can proceed
|
||||
extra = job.extra
|
||||
if stage == PullImageLayerStage.DOWNLOAD_COMPLETE and not job.extra:
|
||||
extra = {"current": 1, "total": 1}
|
||||
|
||||
job.update(
|
||||
progress=progress,
|
||||
stage=stage.status,
|
||||
done=stage == PullImageLayerStage.PULL_COMPLETE,
|
||||
extra=None if stage == PullImageLayerStage.RETRYING_DOWNLOAD else extra,
|
||||
)
|
||||
|
||||
# Once we have received a progress update for every child job, start to set status of the main one
|
||||
install_job = self.sys_jobs.get_job(install_job_id)
|
||||
layer_jobs = [
|
||||
job
|
||||
for job in self.sys_jobs.jobs
|
||||
if job.parent_id == install_job.uuid
|
||||
and job.name == "Pulling container image layer"
|
||||
]
|
||||
|
||||
# First set the total bytes to be downloaded/extracted on the main job
|
||||
if not install_job.extra:
|
||||
total = 0
|
||||
for job in layer_jobs:
|
||||
if not job.extra:
|
||||
return
|
||||
total += job.extra["total"]
|
||||
install_job.extra = {"total": total}
|
||||
else:
|
||||
total = install_job.extra["total"]
|
||||
|
||||
# Then determine total progress based on progress of each sub-job, factoring in size of each compared to total
|
||||
progress = 0.0
|
||||
stage = PullImageLayerStage.PULL_COMPLETE
|
||||
for job in layer_jobs:
|
||||
if not job.extra or not job.extra.get("total"):
|
||||
return
|
||||
progress += job.progress * (job.extra["total"] / total)
|
||||
job_stage = PullImageLayerStage.from_status(cast(str, job.stage))
|
||||
|
||||
if job_stage < PullImageLayerStage.EXTRACTING:
|
||||
stage = PullImageLayerStage.DOWNLOADING
|
||||
elif (
|
||||
stage == PullImageLayerStage.PULL_COMPLETE
|
||||
and job_stage < PullImageLayerStage.PULL_COMPLETE
|
||||
):
|
||||
stage = PullImageLayerStage.EXTRACTING
|
||||
|
||||
# Ensure progress is 100 at this point to prevent float drift
|
||||
if stage == PullImageLayerStage.PULL_COMPLETE:
|
||||
progress = 100
|
||||
|
||||
# To reduce noise, limit updates to when result has changed by an entire percent or when stage changed
|
||||
if stage != install_job.stage or progress >= install_job.progress + 1:
|
||||
install_job.update(stage=stage.status, progress=max(0, min(progress, 100)))
|
||||
|
||||
@Job(
|
||||
name="docker_interface_install",
|
||||
on_condition=DockerJobError,
|
||||
@@ -375,48 +215,55 @@ class DockerInterface(JobGroup, ABC):
|
||||
raise ValueError("Cannot pull without an image!")
|
||||
|
||||
image_arch = arch or self.sys_arch.supervisor
|
||||
listener: EventListener | None = None
|
||||
platform = MAP_ARCH[image_arch]
|
||||
pull_progress = ImagePullProgress()
|
||||
current_job = self.sys_jobs.current
|
||||
|
||||
# Try to fetch manifest for accurate size-based progress
|
||||
# This is optional - if it fails, we fall back to count-based progress
|
||||
try:
|
||||
manifest = await self.sys_docker.manifest_fetcher.get_manifest(
|
||||
image, str(version), platform=platform
|
||||
)
|
||||
if manifest:
|
||||
pull_progress.set_manifest(manifest)
|
||||
_LOGGER.debug(
|
||||
"Using manifest for progress: %d layers, %d bytes",
|
||||
manifest.layer_count,
|
||||
manifest.total_size,
|
||||
)
|
||||
except (aiohttp.ClientError, TimeoutError) as err:
|
||||
_LOGGER.warning("Could not fetch manifest for progress: %s", err)
|
||||
|
||||
async def process_pull_event(event: PullLogEntry) -> None:
|
||||
"""Process pull event and update job progress."""
|
||||
if event.job_id != current_job.uuid:
|
||||
return
|
||||
|
||||
# Process event through progress tracker
|
||||
pull_progress.process_event(event)
|
||||
|
||||
# Update job if progress changed significantly (>= 1%)
|
||||
should_update, progress = pull_progress.should_update_job()
|
||||
if should_update:
|
||||
stage = pull_progress.get_stage()
|
||||
current_job.update(progress=progress, stage=stage)
|
||||
|
||||
listener = self.sys_bus.register_event(
|
||||
BusEvent.DOCKER_IMAGE_PULL_UPDATE, process_pull_event
|
||||
)
|
||||
|
||||
_LOGGER.info("Downloading docker image %s with tag %s.", image, version)
|
||||
try:
|
||||
# Get credentials for private registries to pass to aiodocker
|
||||
credentials = self._get_credentials(image) or None
|
||||
|
||||
curr_job_id = self.sys_jobs.current.uuid
|
||||
|
||||
async def process_pull_image_log(reference: PullLogEntry) -> None:
|
||||
try:
|
||||
self._process_pull_image_log(curr_job_id, reference)
|
||||
except DockerLogOutOfOrder as err:
|
||||
# Send all these to sentry. Missing a few progress updates
|
||||
# shouldn't matter to users but matters to us
|
||||
await async_capture_exception(err)
|
||||
except ValueError as err:
|
||||
# Catch "Cannot update a job that is done" errors which occur under
|
||||
# some not clearly understood combination of events. Log with context
|
||||
# and send to Sentry to track frequency and gather debugging info.
|
||||
if "Cannot update a job that is done" in str(err):
|
||||
_LOGGER.warning(
|
||||
"Unexpected job state during pull: %s (layer: %s, status: %s, progress: %s)",
|
||||
err,
|
||||
reference.id,
|
||||
reference.status,
|
||||
reference.progress,
|
||||
)
|
||||
await async_capture_exception(err)
|
||||
else:
|
||||
raise
|
||||
|
||||
listener = self.sys_bus.register_event(
|
||||
BusEvent.DOCKER_IMAGE_PULL_UPDATE, process_pull_image_log
|
||||
)
|
||||
|
||||
# Pull new image, passing credentials to aiodocker
|
||||
docker_image = await self.sys_docker.pull_image(
|
||||
self.sys_jobs.current.uuid,
|
||||
current_job.uuid,
|
||||
image,
|
||||
str(version),
|
||||
platform=MAP_ARCH[image_arch],
|
||||
platform=platform,
|
||||
auth=credentials,
|
||||
)
|
||||
|
||||
@@ -441,8 +288,7 @@ class DockerInterface(JobGroup, ABC):
|
||||
f"Can't install {image}:{version!s}: {err}", _LOGGER.error
|
||||
) from err
|
||||
finally:
|
||||
if listener:
|
||||
self.sys_bus.remove_listener(listener)
|
||||
self.sys_bus.remove_listener(listener)
|
||||
|
||||
self._meta = docker_image
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ from .const import (
|
||||
RestartPolicy,
|
||||
Ulimit,
|
||||
)
|
||||
from .manifest import RegistryManifestFetcher
|
||||
from .monitor import DockerMonitor
|
||||
from .network import DockerNetwork
|
||||
from .utils import get_registry_from_image
|
||||
@@ -275,6 +276,9 @@ class DockerAPI(CoreSysAttributes):
|
||||
self._info: DockerInfo | None = None
|
||||
self.config: DockerConfig = DockerConfig()
|
||||
self._monitor: DockerMonitor = DockerMonitor(coresys)
|
||||
self._manifest_fetcher: RegistryManifestFetcher = RegistryManifestFetcher(
|
||||
coresys
|
||||
)
|
||||
|
||||
async def post_init(self) -> Self:
|
||||
"""Post init actions that must be done in event loop."""
|
||||
@@ -335,6 +339,11 @@ class DockerAPI(CoreSysAttributes):
|
||||
"""Return docker events monitor."""
|
||||
return self._monitor
|
||||
|
||||
@property
|
||||
def manifest_fetcher(self) -> RegistryManifestFetcher:
|
||||
"""Return manifest fetcher for registry access."""
|
||||
return self._manifest_fetcher
|
||||
|
||||
async def load(self) -> None:
|
||||
"""Start docker events monitor."""
|
||||
await self.monitor.load()
|
||||
|
||||
339
supervisor/docker/manifest.py
Normal file
339
supervisor/docker/manifest.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Docker registry manifest fetcher.
|
||||
|
||||
Fetches image manifests directly from container registries to get layer sizes
|
||||
before pulling an image. This enables accurate size-based progress tracking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from supervisor.docker.utils import get_registry_from_image
|
||||
|
||||
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..coresys import CoreSys
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Media types for manifest requests
|
||||
MANIFEST_MEDIA_TYPES = (
|
||||
"application/vnd.docker.distribution.manifest.v2+json",
|
||||
"application/vnd.oci.image.manifest.v1+json",
|
||||
"application/vnd.docker.distribution.manifest.list.v2+json",
|
||||
"application/vnd.oci.image.index.v1+json",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageManifest:
|
||||
"""Container image manifest with layer information."""
|
||||
|
||||
digest: str
|
||||
total_size: int
|
||||
layers: dict[str, int] # digest -> size in bytes
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
"""Return number of layers."""
|
||||
return len(self.layers)
|
||||
|
||||
|
||||
def parse_image_reference(image: str, tag: str) -> tuple[str, str, str]:
|
||||
"""Parse image reference into (registry, repository, tag).
|
||||
|
||||
Examples:
|
||||
ghcr.io/home-assistant/home-assistant:2025.1.0
|
||||
-> (ghcr.io, home-assistant/home-assistant, 2025.1.0)
|
||||
homeassistant/home-assistant:latest
|
||||
-> (registry-1.docker.io, homeassistant/home-assistant, latest)
|
||||
alpine:3.18
|
||||
-> (registry-1.docker.io, library/alpine, 3.18)
|
||||
|
||||
"""
|
||||
# Check if image has explicit registry host
|
||||
registry = get_registry_from_image(image)
|
||||
if registry:
|
||||
repository = image[len(registry) + 1 :] # Remove "registry/" prefix
|
||||
else:
|
||||
registry = DOCKER_HUB
|
||||
repository = image
|
||||
# Docker Hub requires "library/" prefix for official images
|
||||
if "/" not in repository:
|
||||
repository = f"library/{repository}"
|
||||
|
||||
return registry, repository, tag
|
||||
|
||||
|
||||
class RegistryManifestFetcher:
|
||||
"""Fetches manifests from container registries."""
|
||||
|
||||
def __init__(self, coresys: CoreSys) -> None:
|
||||
"""Initialize the fetcher."""
|
||||
self.coresys = coresys
|
||||
|
||||
@property
|
||||
def _session(self) -> aiohttp.ClientSession:
|
||||
"""Return the websession for HTTP requests."""
|
||||
return self.coresys.websession
|
||||
|
||||
def _get_credentials(self, registry: str) -> tuple[str, str] | None:
|
||||
"""Get credentials for registry from Docker config.
|
||||
|
||||
Returns (username, password) tuple or None if no credentials.
|
||||
"""
|
||||
registries = self.coresys.docker.config.registries
|
||||
|
||||
# Map registry hostname to config key
|
||||
# Docker Hub can be stored as "hub.docker.com" in config
|
||||
if registry in (DOCKER_HUB, DOCKER_HUB_LEGACY):
|
||||
if DOCKER_HUB in registries:
|
||||
creds = registries[DOCKER_HUB]
|
||||
return creds.get("username"), creds.get("password")
|
||||
elif registry in registries:
|
||||
creds = registries[registry]
|
||||
return creds.get("username"), creds.get("password")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_auth_token(
|
||||
self,
|
||||
registry: str,
|
||||
repository: str,
|
||||
) -> str | None:
|
||||
"""Get authentication token for registry.
|
||||
|
||||
Uses the WWW-Authenticate header from a 401 response to discover
|
||||
the token endpoint, then requests a token with appropriate scope.
|
||||
"""
|
||||
# First, make an unauthenticated request to get WWW-Authenticate header
|
||||
manifest_url = f"https://{registry}/v2/{repository}/manifests/latest"
|
||||
|
||||
try:
|
||||
async with self._session.get(manifest_url) as resp:
|
||||
if resp.status == 200:
|
||||
# No auth required
|
||||
return None
|
||||
|
||||
if resp.status != 401:
|
||||
_LOGGER.warning(
|
||||
"Unexpected status %d from registry %s", resp.status, registry
|
||||
)
|
||||
return None
|
||||
|
||||
www_auth = resp.headers.get("WWW-Authenticate", "")
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to connect to registry %s: %s", registry, err)
|
||||
return None
|
||||
|
||||
# Parse WWW-Authenticate: Bearer realm="...",service="...",scope="..."
|
||||
if not www_auth.startswith("Bearer "):
|
||||
_LOGGER.warning("Unsupported auth type from %s: %s", registry, www_auth)
|
||||
return None
|
||||
|
||||
params = {}
|
||||
for match in re.finditer(r'(\w+)="([^"]*)"', www_auth):
|
||||
params[match.group(1)] = match.group(2)
|
||||
|
||||
realm = params.get("realm")
|
||||
service = params.get("service")
|
||||
|
||||
if not realm:
|
||||
_LOGGER.warning("No realm in WWW-Authenticate from %s", registry)
|
||||
return None
|
||||
|
||||
# Build token request URL
|
||||
token_url = f"{realm}?scope=repository:{repository}:pull"
|
||||
if service:
|
||||
token_url += f"&service={service}"
|
||||
|
||||
# Check for credentials
|
||||
auth = None
|
||||
credentials = self._get_credentials(registry)
|
||||
if credentials:
|
||||
username, password = credentials
|
||||
if username and password:
|
||||
auth = aiohttp.BasicAuth(username, password)
|
||||
_LOGGER.debug("Using credentials for %s", registry)
|
||||
|
||||
try:
|
||||
async with self._session.get(token_url, auth=auth) as resp:
|
||||
if resp.status != 200:
|
||||
_LOGGER.warning(
|
||||
"Failed to get token from %s: %d", realm, resp.status
|
||||
)
|
||||
return None
|
||||
|
||||
data = await resp.json()
|
||||
return data.get("token") or data.get("access_token")
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to get auth token: %s", err)
|
||||
return None
|
||||
|
||||
async def _fetch_manifest(
|
||||
self,
|
||||
registry: str,
|
||||
repository: str,
|
||||
reference: str,
|
||||
token: str | None,
|
||||
platform: str,
|
||||
) -> dict | None:
|
||||
"""Fetch manifest from registry.
|
||||
|
||||
If the manifest is a manifest list (multi-arch), fetches the
|
||||
platform-specific manifest.
|
||||
"""
|
||||
manifest_url = f"https://{registry}/v2/{repository}/manifests/{reference}"
|
||||
|
||||
headers = {"Accept": ", ".join(MANIFEST_MEDIA_TYPES)}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
try:
|
||||
async with self._session.get(manifest_url, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
_LOGGER.warning(
|
||||
"Failed to fetch manifest for %s/%s:%s - %d",
|
||||
registry,
|
||||
repository,
|
||||
reference,
|
||||
resp.status,
|
||||
)
|
||||
return None
|
||||
|
||||
manifest = await resp.json()
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to fetch manifest: %s", err)
|
||||
return None
|
||||
|
||||
media_type = manifest.get("mediaType", "")
|
||||
|
||||
# Check if this is a manifest list (multi-arch image)
|
||||
if "list" in media_type or "index" in media_type:
|
||||
manifests = manifest.get("manifests", [])
|
||||
if not manifests:
|
||||
_LOGGER.warning("Empty manifest list for %s/%s", registry, repository)
|
||||
return None
|
||||
|
||||
# Platform format is "linux/amd64", "linux/arm64", etc.
|
||||
parts = platform.split("/")
|
||||
if len(parts) < 2:
|
||||
_LOGGER.warning("Invalid platform format: %s", platform)
|
||||
return None
|
||||
|
||||
target_os, target_arch = parts[0], parts[1]
|
||||
|
||||
platform_manifest = None
|
||||
for m in manifests:
|
||||
plat = m.get("platform", {})
|
||||
if (
|
||||
plat.get("os") == target_os
|
||||
and plat.get("architecture") == target_arch
|
||||
):
|
||||
platform_manifest = m
|
||||
break
|
||||
|
||||
if not platform_manifest:
|
||||
_LOGGER.warning(
|
||||
"Platform %s/%s not found in manifest list for %s/%s, "
|
||||
"cannot use manifest for progress tracking",
|
||||
target_os,
|
||||
target_arch,
|
||||
registry,
|
||||
repository,
|
||||
)
|
||||
return None
|
||||
|
||||
# Fetch the platform-specific manifest
|
||||
return await self._fetch_manifest(
|
||||
registry,
|
||||
repository,
|
||||
platform_manifest["digest"],
|
||||
token,
|
||||
platform,
|
||||
)
|
||||
|
||||
return manifest
|
||||
|
||||
async def get_manifest(
|
||||
self,
|
||||
image: str,
|
||||
tag: str,
|
||||
platform: str,
|
||||
) -> ImageManifest | None:
|
||||
"""Fetch manifest and extract layer sizes.
|
||||
|
||||
Args:
|
||||
image: Image name (e.g., "ghcr.io/home-assistant/home-assistant")
|
||||
tag: Image tag (e.g., "2025.1.0")
|
||||
platform: Target platform (e.g., "linux/amd64")
|
||||
|
||||
Returns:
|
||||
ImageManifest with layer sizes, or None if fetch failed.
|
||||
|
||||
"""
|
||||
registry, repository, tag = parse_image_reference(image, tag)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Fetching manifest for %s/%s:%s (platform=%s)",
|
||||
registry,
|
||||
repository,
|
||||
tag,
|
||||
platform,
|
||||
)
|
||||
|
||||
# Get auth token
|
||||
token = await self._get_auth_token(registry, repository)
|
||||
|
||||
# Fetch manifest
|
||||
manifest = await self._fetch_manifest(
|
||||
registry, repository, tag, token, platform
|
||||
)
|
||||
|
||||
if not manifest:
|
||||
return None
|
||||
|
||||
# Extract layer information
|
||||
layers = manifest.get("layers", [])
|
||||
if not layers:
|
||||
_LOGGER.warning(
|
||||
"No layers in manifest for %s/%s:%s", registry, repository, tag
|
||||
)
|
||||
return None
|
||||
|
||||
layer_sizes: dict[str, int] = {}
|
||||
total_size = 0
|
||||
|
||||
for layer in layers:
|
||||
digest = layer.get("digest", "")
|
||||
size = layer.get("size", 0)
|
||||
if digest and size:
|
||||
# Store by short digest (first 12 chars after sha256:)
|
||||
short_digest = (
|
||||
digest.split(":")[1][:12] if ":" in digest else digest[:12]
|
||||
)
|
||||
layer_sizes[short_digest] = size
|
||||
total_size += size
|
||||
|
||||
digest = manifest.get("config", {}).get("digest", "")
|
||||
|
||||
_LOGGER.debug(
|
||||
"Manifest for %s/%s:%s - %d layers, %d bytes total",
|
||||
registry,
|
||||
repository,
|
||||
tag,
|
||||
len(layer_sizes),
|
||||
total_size,
|
||||
)
|
||||
|
||||
return ImageManifest(
|
||||
digest=digest,
|
||||
total_size=total_size,
|
||||
layers=layer_sizes,
|
||||
)
|
||||
368
supervisor/docker/pull_progress.py
Normal file
368
supervisor/docker/pull_progress.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Image pull progress tracking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .manager import PullLogEntry
|
||||
from .manifest import ImageManifest
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Progress weight distribution: 70% downloading, 30% extraction
|
||||
DOWNLOAD_WEIGHT = 70.0
|
||||
EXTRACT_WEIGHT = 30.0
|
||||
|
||||
|
||||
class LayerPullStatus(Enum):
|
||||
"""Status values for pulling an image layer.
|
||||
|
||||
These are a subset of the statuses in a docker pull image log.
|
||||
The order field allows comparing which stage is further along.
|
||||
"""
|
||||
|
||||
PULLING_FS_LAYER = 1, "Pulling fs layer"
|
||||
WAITING = 1, "Waiting"
|
||||
RETRYING = 2, "Retrying" # Matches "Retrying in N seconds"
|
||||
DOWNLOADING = 3, "Downloading"
|
||||
VERIFYING_CHECKSUM = 4, "Verifying Checksum"
|
||||
DOWNLOAD_COMPLETE = 5, "Download complete"
|
||||
EXTRACTING = 6, "Extracting"
|
||||
PULL_COMPLETE = 7, "Pull complete"
|
||||
ALREADY_EXISTS = 7, "Already exists"
|
||||
|
||||
def __init__(self, order: int, status: str) -> None:
|
||||
"""Set fields from values."""
|
||||
self.order = order
|
||||
self.status = status
|
||||
|
||||
def __eq__(self, value: object, /) -> bool:
|
||||
"""Check equality, allow string comparisons on status."""
|
||||
with suppress(AttributeError):
|
||||
return self.status == cast(LayerPullStatus, value).status
|
||||
return self.status == value
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash based on status string."""
|
||||
return hash(self.status)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
"""Order instances by stage progression."""
|
||||
with suppress(AttributeError):
|
||||
return self.order < cast(LayerPullStatus, other).order
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_status(cls, status: str) -> LayerPullStatus | None:
|
||||
"""Get enum from status string, or None if not recognized."""
|
||||
# Handle "Retrying in N seconds" pattern
|
||||
if status.startswith("Retrying in "):
|
||||
return cls.RETRYING
|
||||
for member in cls:
|
||||
if member.status == status:
|
||||
return member
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerProgress:
|
||||
"""Track progress of a single layer."""
|
||||
|
||||
layer_id: str
|
||||
total_size: int = 0 # Size in bytes (from downloading, reused for extraction)
|
||||
download_current: int = 0
|
||||
extract_current: int = 0 # Extraction progress in bytes (overlay2 only)
|
||||
download_complete: bool = False
|
||||
extract_complete: bool = False
|
||||
already_exists: bool = False # Layer was already locally available
|
||||
|
||||
def calculate_progress(self) -> float:
|
||||
"""Calculate layer progress 0-100.
|
||||
|
||||
Progress is weighted: 70% download, 30% extraction.
|
||||
For overlay2, we have byte-based extraction progress.
|
||||
For containerd, extraction jumps from 70% to 100% on completion.
|
||||
"""
|
||||
if self.already_exists or self.extract_complete:
|
||||
return 100.0
|
||||
|
||||
if self.download_complete:
|
||||
# Check if we have extraction progress (overlay2)
|
||||
if self.extract_current > 0 and self.total_size > 0:
|
||||
extract_pct = min(1.0, self.extract_current / self.total_size)
|
||||
return DOWNLOAD_WEIGHT + (extract_pct * EXTRACT_WEIGHT)
|
||||
# No extraction progress yet - return 70%
|
||||
return DOWNLOAD_WEIGHT
|
||||
|
||||
if self.total_size > 0:
|
||||
download_pct = min(1.0, self.download_current / self.total_size)
|
||||
return download_pct * DOWNLOAD_WEIGHT
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImagePullProgress:
|
||||
"""Track overall progress of pulling an image.
|
||||
|
||||
When manifest layer sizes are provided, uses size-weighted progress where
|
||||
each layer contributes proportionally to its size. This gives accurate
|
||||
progress based on actual bytes to download.
|
||||
|
||||
When manifest is not available, falls back to count-based progress where
|
||||
each layer contributes equally.
|
||||
|
||||
Layers that already exist locally are excluded from the progress calculation.
|
||||
"""
|
||||
|
||||
layers: dict[str, LayerProgress] = field(default_factory=dict)
|
||||
_last_reported_progress: float = field(default=0.0, repr=False)
|
||||
_seen_downloading: bool = field(default=False, repr=False)
|
||||
_manifest_layer_sizes: dict[str, int] = field(default_factory=dict, repr=False)
|
||||
_total_manifest_size: int = field(default=0, repr=False)
|
||||
|
||||
def set_manifest(self, manifest: ImageManifest) -> None:
|
||||
"""Set manifest layer sizes for accurate size-based progress.
|
||||
|
||||
Should be called before processing pull events.
|
||||
"""
|
||||
self._manifest_layer_sizes = dict(manifest.layers)
|
||||
self._total_manifest_size = manifest.total_size
|
||||
_LOGGER.debug(
|
||||
"Manifest set: %d layers, %d bytes total",
|
||||
len(self._manifest_layer_sizes),
|
||||
self._total_manifest_size,
|
||||
)
|
||||
|
||||
def get_or_create_layer(self, layer_id: str) -> LayerProgress:
|
||||
"""Get existing layer or create new one."""
|
||||
if layer_id not in self.layers:
|
||||
# If we have manifest sizes, pre-populate the layer's total_size
|
||||
manifest_size = self._manifest_layer_sizes.get(layer_id, 0)
|
||||
self.layers[layer_id] = LayerProgress(
|
||||
layer_id=layer_id, total_size=manifest_size
|
||||
)
|
||||
return self.layers[layer_id]
|
||||
|
||||
def process_event(self, entry: PullLogEntry) -> None:
|
||||
"""Process a pull log event and update layer state."""
|
||||
# Skip events without layer ID or status
|
||||
if not entry.id or not entry.status:
|
||||
return
|
||||
|
||||
# Skip metadata events that aren't layer-specific
|
||||
# "Pulling from X" has id=tag but isn't a layer
|
||||
if entry.status.startswith("Pulling from "):
|
||||
return
|
||||
|
||||
# Parse status to enum (returns None for unrecognized statuses)
|
||||
status = LayerPullStatus.from_status(entry.status)
|
||||
if status is None:
|
||||
return
|
||||
|
||||
layer = self.get_or_create_layer(entry.id)
|
||||
|
||||
# Handle "Already exists" - layer is locally available
|
||||
if status is LayerPullStatus.ALREADY_EXISTS:
|
||||
layer.already_exists = True
|
||||
layer.download_complete = True
|
||||
layer.extract_complete = True
|
||||
return
|
||||
|
||||
# Handle "Pulling fs layer" / "Waiting" - layer is being tracked
|
||||
if status in (LayerPullStatus.PULLING_FS_LAYER, LayerPullStatus.WAITING):
|
||||
return
|
||||
|
||||
# Handle "Downloading" - update download progress
|
||||
if status is LayerPullStatus.DOWNLOADING:
|
||||
# Mark that we've seen downloading - now we know layer count is complete
|
||||
self._seen_downloading = True
|
||||
if entry.progress_detail and entry.progress_detail.current is not None:
|
||||
layer.download_current = entry.progress_detail.current
|
||||
if entry.progress_detail and entry.progress_detail.total is not None:
|
||||
# Only set total_size if not already set or if this is larger
|
||||
# (handles case where total changes during download)
|
||||
layer.total_size = max(layer.total_size, entry.progress_detail.total)
|
||||
return
|
||||
|
||||
# Handle "Verifying Checksum" - download is essentially complete
|
||||
if status is LayerPullStatus.VERIFYING_CHECKSUM:
|
||||
if layer.total_size > 0:
|
||||
layer.download_current = layer.total_size
|
||||
return
|
||||
|
||||
# Handle "Download complete" - download phase done
|
||||
if status is LayerPullStatus.DOWNLOAD_COMPLETE:
|
||||
layer.download_complete = True
|
||||
if layer.total_size > 0:
|
||||
layer.download_current = layer.total_size
|
||||
elif layer.total_size == 0:
|
||||
# Small layer that skipped downloading phase
|
||||
# Set minimal size so it doesn't distort weighted average
|
||||
layer.total_size = 1
|
||||
layer.download_current = 1
|
||||
return
|
||||
|
||||
# Handle "Extracting" - extraction in progress
|
||||
if status is LayerPullStatus.EXTRACTING:
|
||||
# For overlay2: progressDetail has {current, total} in bytes
|
||||
# For containerd: progressDetail has {current, units: "s"} (time elapsed)
|
||||
# We can only use byte-based progress (overlay2)
|
||||
layer.download_complete = True
|
||||
if layer.total_size > 0:
|
||||
layer.download_current = layer.total_size
|
||||
|
||||
# Check if this is byte-based extraction progress (overlay2)
|
||||
# Overlay2 has {current, total} in bytes, no units field
|
||||
# Containerd has {current, units: "s"} which is useless for progress
|
||||
if (
|
||||
entry.progress_detail
|
||||
and entry.progress_detail.current is not None
|
||||
and entry.progress_detail.units is None
|
||||
):
|
||||
# Use layer's total_size from downloading phase (doesn't change)
|
||||
layer.extract_current = entry.progress_detail.current
|
||||
_LOGGER.debug(
|
||||
"Layer %s extracting: %d/%d (%.1f%%)",
|
||||
layer.layer_id,
|
||||
layer.extract_current,
|
||||
layer.total_size,
|
||||
(layer.extract_current / layer.total_size * 100)
|
||||
if layer.total_size > 0
|
||||
else 0,
|
||||
)
|
||||
return
|
||||
|
||||
# Handle "Pull complete" - layer is fully done
|
||||
if status is LayerPullStatus.PULL_COMPLETE:
|
||||
layer.download_complete = True
|
||||
layer.extract_complete = True
|
||||
if layer.total_size > 0:
|
||||
layer.download_current = layer.total_size
|
||||
return
|
||||
|
||||
# Handle "Retrying in N seconds" - reset download progress
|
||||
if status is LayerPullStatus.RETRYING:
|
||||
layer.download_current = 0
|
||||
layer.download_complete = False
|
||||
return
|
||||
|
||||
def calculate_progress(self) -> float:
|
||||
"""Calculate overall progress 0-100.
|
||||
|
||||
When manifest layer sizes are available, uses size-weighted progress
|
||||
where each layer contributes proportionally to its size.
|
||||
|
||||
When manifest is not available, falls back to count-based progress
|
||||
where each layer contributes equally.
|
||||
|
||||
Layers that already exist locally are excluded from the calculation.
|
||||
|
||||
Returns 0 until we've seen the first "Downloading" event, since Docker
|
||||
reports "Already exists" and "Pulling fs layer" events before we know
|
||||
the complete layer count.
|
||||
"""
|
||||
# Don't report progress until we've seen downloading start
|
||||
# This ensures we know the full layer count before calculating progress
|
||||
if not self._seen_downloading or not self.layers:
|
||||
return 0.0
|
||||
|
||||
# Only count layers that need pulling (exclude already_exists)
|
||||
layers_to_pull = [
|
||||
layer for layer in self.layers.values() if not layer.already_exists
|
||||
]
|
||||
|
||||
if not layers_to_pull:
|
||||
# All layers already exist, nothing to download
|
||||
return 100.0
|
||||
|
||||
# Use size-weighted progress if manifest sizes are available
|
||||
if self._manifest_layer_sizes:
|
||||
return min(100, self._calculate_size_weighted_progress(layers_to_pull))
|
||||
|
||||
# Fall back to count-based progress
|
||||
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
|
||||
return min(100, total_progress / len(layers_to_pull))
|
||||
|
||||
def _calculate_size_weighted_progress(
|
||||
self, layers_to_pull: list[LayerProgress]
|
||||
) -> float:
|
||||
"""Calculate size-weighted progress.
|
||||
|
||||
Each layer contributes to progress proportionally to its size.
|
||||
Progress = sum(layer_progress * layer_size) / total_size
|
||||
"""
|
||||
# Calculate total size of layers that need pulling
|
||||
total_size = sum(layer.total_size for layer in layers_to_pull)
|
||||
|
||||
if total_size == 0:
|
||||
# No size info available, fall back to count-based
|
||||
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
|
||||
return total_progress / len(layers_to_pull)
|
||||
|
||||
# Weight each layer's progress by its size
|
||||
weighted_progress = 0.0
|
||||
for layer in layers_to_pull:
|
||||
if layer.total_size > 0:
|
||||
layer_weight = layer.total_size / total_size
|
||||
weighted_progress += layer.calculate_progress() * layer_weight
|
||||
|
||||
return weighted_progress
|
||||
|
||||
def get_stage(self) -> str | None:
|
||||
"""Get current stage based on layer states."""
|
||||
if not self.layers:
|
||||
return None
|
||||
|
||||
# Check if any layer is still downloading
|
||||
for layer in self.layers.values():
|
||||
if layer.already_exists:
|
||||
continue
|
||||
if not layer.download_complete:
|
||||
return "Downloading"
|
||||
|
||||
# All downloads complete, check if extracting
|
||||
for layer in self.layers.values():
|
||||
if layer.already_exists:
|
||||
continue
|
||||
if not layer.extract_complete:
|
||||
return "Extracting"
|
||||
|
||||
# All done
|
||||
return "Pull complete"
|
||||
|
||||
def should_update_job(self, threshold: float = 1.0) -> tuple[bool, float]:
|
||||
"""Check if job should be updated based on progress change.
|
||||
|
||||
Returns (should_update, current_progress).
|
||||
Updates are triggered when progress changes by at least threshold%.
|
||||
Progress is guaranteed to only increase (monotonic).
|
||||
"""
|
||||
current_progress = self.calculate_progress()
|
||||
|
||||
# Ensure monotonic progress - never report a decrease
|
||||
# This can happen when new layers get size info and change the weighted average
|
||||
if current_progress < self._last_reported_progress:
|
||||
_LOGGER.debug(
|
||||
"Progress decreased from %.1f%% to %.1f%%, keeping last reported",
|
||||
self._last_reported_progress,
|
||||
current_progress,
|
||||
)
|
||||
return False, self._last_reported_progress
|
||||
|
||||
if current_progress >= self._last_reported_progress + threshold:
|
||||
_LOGGER.debug(
|
||||
"Progress update: %.1f%% -> %.1f%% (delta: %.1f%%)",
|
||||
self._last_reported_progress,
|
||||
current_progress,
|
||||
current_progress - self._last_reported_progress,
|
||||
)
|
||||
self._last_reported_progress = current_progress
|
||||
return True, current_progress
|
||||
|
||||
return False, self._last_reported_progress
|
||||
@@ -855,10 +855,6 @@ class DockerNotFound(DockerError):
|
||||
"""Docker object don't Exists."""
|
||||
|
||||
|
||||
class DockerLogOutOfOrder(DockerError):
|
||||
"""Raise when log from docker action was out of order."""
|
||||
|
||||
|
||||
class DockerNoSpaceOnDevice(DockerError):
|
||||
"""Raise if a docker pull fails due to available space."""
|
||||
|
||||
|
||||
@@ -306,6 +306,8 @@ async def test_api_progress_updates_home_assistant_update(
|
||||
and evt.args[0]["data"]["event"] == WSEvent.JOB
|
||||
and evt.args[0]["data"]["data"]["name"] == "home_assistant_core_update"
|
||||
]
|
||||
# Count-based progress: 2 layers need pulling (each worth 50%)
|
||||
# Layers that already exist are excluded from progress calculation
|
||||
assert events[:5] == [
|
||||
{
|
||||
"stage": None,
|
||||
@@ -319,36 +321,36 @@ async def test_api_progress_updates_home_assistant_update(
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 0.1,
|
||||
"progress": 9.2,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 1.7,
|
||||
"progress": 25.6,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 4.0,
|
||||
"progress": 35.4,
|
||||
"done": False,
|
||||
},
|
||||
]
|
||||
assert events[-5:] == [
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 95.5,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 96.9,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 98.2,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 98.3,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 99.3,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 100,
|
||||
|
||||
@@ -840,6 +840,8 @@ async def test_api_progress_updates_addon_install_update(
|
||||
and evt.args[0]["data"]["data"]["name"] == job_name
|
||||
and evt.args[0]["data"]["data"]["reference"] == addon_slug
|
||||
]
|
||||
# Count-based progress: 2 layers need pulling (each worth 50%)
|
||||
# Layers that already exist are excluded from progress calculation
|
||||
assert events[:4] == [
|
||||
{
|
||||
"stage": None,
|
||||
@@ -848,36 +850,36 @@ async def test_api_progress_updates_addon_install_update(
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 0.1,
|
||||
"progress": 9.2,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 1.7,
|
||||
"progress": 25.6,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 4.0,
|
||||
"progress": 35.4,
|
||||
"done": False,
|
||||
},
|
||||
]
|
||||
assert events[-5:] == [
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 95.5,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 96.9,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 98.2,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 98.3,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 99.3,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 100,
|
||||
|
||||
@@ -360,6 +360,8 @@ async def test_api_progress_updates_supervisor_update(
|
||||
and evt.args[0]["data"]["event"] == WSEvent.JOB
|
||||
and evt.args[0]["data"]["data"]["name"] == "supervisor_update"
|
||||
]
|
||||
# Count-based progress: 2 layers need pulling (each worth 50%)
|
||||
# Layers that already exist are excluded from progress calculation
|
||||
assert events[:4] == [
|
||||
{
|
||||
"stage": None,
|
||||
@@ -368,36 +370,36 @@ async def test_api_progress_updates_supervisor_update(
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 0.1,
|
||||
"progress": 9.2,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 1.7,
|
||||
"progress": 25.6,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 4.0,
|
||||
"progress": 35.4,
|
||||
"done": False,
|
||||
},
|
||||
]
|
||||
assert events[-5:] == [
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 95.5,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 96.9,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 98.2,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 98.3,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 99.3,
|
||||
"done": False,
|
||||
},
|
||||
{
|
||||
"stage": None,
|
||||
"progress": 100,
|
||||
|
||||
@@ -202,6 +202,9 @@ async def docker() -> DockerAPI:
|
||||
docker_obj.info.storage = "overlay2"
|
||||
docker_obj.info.version = AwesomeVersion("1.0.0")
|
||||
|
||||
# Mock manifest fetcher to return None (falls back to count-based progress)
|
||||
docker_obj._manifest_fetcher.get_manifest = AsyncMock(return_value=None)
|
||||
|
||||
yield docker_obj
|
||||
|
||||
|
||||
|
||||
@@ -704,11 +704,18 @@ async def test_install_progress_handles_layers_skipping_download(
|
||||
await install_task
|
||||
await event.wait()
|
||||
|
||||
# First update from layer download should have rather low progress ((260937/25445459) / 2 ~ 0.5%)
|
||||
assert install_job_snapshots[0]["progress"] < 1
|
||||
# With the new progress calculation approach:
|
||||
# - Progress is weighted by layer size
|
||||
# - Small layers that skip downloading get minimal size (1 byte)
|
||||
# - Progress should increase monotonically
|
||||
assert len(install_job_snapshots) > 0
|
||||
|
||||
# Total 8 events should lead to a progress update on the install job
|
||||
assert len(install_job_snapshots) == 8
|
||||
# Verify progress is monotonically increasing (or stable)
|
||||
for i in range(1, len(install_job_snapshots)):
|
||||
assert (
|
||||
install_job_snapshots[i]["progress"]
|
||||
>= install_job_snapshots[i - 1]["progress"]
|
||||
)
|
||||
|
||||
# Job should complete successfully
|
||||
assert job.done is True
|
||||
@@ -842,24 +849,24 @@ async def test_install_progress_containerd_snapshot(
|
||||
}
|
||||
|
||||
assert [c.args[0] for c in ha_ws_client.async_send_command.call_args_list] == [
|
||||
# During downloading we get continuous progress updates from download status
|
||||
# Count-based progress: 2 layers, each = 50%. Download = 0-35%, Extract = 35-50%
|
||||
job_event(0),
|
||||
job_event(1.7),
|
||||
job_event(3.4),
|
||||
job_event(8.5),
|
||||
job_event(8.4),
|
||||
job_event(10.2),
|
||||
job_event(15.3),
|
||||
job_event(18.8),
|
||||
job_event(29.0),
|
||||
job_event(35.8),
|
||||
job_event(42.6),
|
||||
job_event(49.5),
|
||||
job_event(56.0),
|
||||
job_event(62.8),
|
||||
# Downloading phase is considered 70% of total. After we only get one update
|
||||
# per image downloaded when extraction is finished. It uses the total size
|
||||
# received during downloading to determine percent complete then.
|
||||
job_event(15.2),
|
||||
job_event(18.7),
|
||||
job_event(28.8),
|
||||
job_event(35.7),
|
||||
job_event(42.4),
|
||||
job_event(49.3),
|
||||
job_event(55.8),
|
||||
job_event(62.7),
|
||||
# Downloading phase is considered 70% of layer's progress.
|
||||
# After download complete, extraction takes remaining 30% per layer.
|
||||
job_event(70.0),
|
||||
job_event(84.8),
|
||||
job_event(85.0),
|
||||
job_event(100),
|
||||
job_event(100, True),
|
||||
]
|
||||
|
||||
143
tests/docker/test_manifest.py
Normal file
143
tests/docker/test_manifest.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for registry manifest fetcher."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from supervisor.coresys import CoreSys
|
||||
from supervisor.docker.manifest import (
|
||||
DOCKER_HUB,
|
||||
ImageManifest,
|
||||
RegistryManifestFetcher,
|
||||
parse_image_reference,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_image_reference_ghcr_io():
|
||||
"""Test parsing ghcr.io image."""
|
||||
registry, repo, tag = parse_image_reference(
|
||||
"ghcr.io/home-assistant/home-assistant", "2025.1.0"
|
||||
)
|
||||
assert registry == "ghcr.io"
|
||||
assert repo == "home-assistant/home-assistant"
|
||||
assert tag == "2025.1.0"
|
||||
|
||||
|
||||
def test_parse_image_reference_docker_hub_with_org():
|
||||
"""Test parsing Docker Hub image with organization."""
|
||||
registry, repo, tag = parse_image_reference(
|
||||
"homeassistant/home-assistant", "latest"
|
||||
)
|
||||
assert registry == DOCKER_HUB
|
||||
assert repo == "homeassistant/home-assistant"
|
||||
assert tag == "latest"
|
||||
|
||||
|
||||
def test_parse_image_reference_docker_hub_official_image():
|
||||
"""Test parsing Docker Hub official image (no org)."""
|
||||
registry, repo, tag = parse_image_reference("alpine", "3.18")
|
||||
assert registry == DOCKER_HUB
|
||||
assert repo == "library/alpine"
|
||||
assert tag == "3.18"
|
||||
|
||||
|
||||
def test_parse_image_reference_gcr_io():
|
||||
"""Test parsing gcr.io image."""
|
||||
registry, repo, tag = parse_image_reference("gcr.io/project/image", "v1")
|
||||
assert registry == "gcr.io"
|
||||
assert repo == "project/image"
|
||||
assert tag == "v1"
|
||||
|
||||
|
||||
def test_image_manifest_layer_count():
|
||||
"""Test ImageManifest layer_count property."""
|
||||
manifest = ImageManifest(
|
||||
digest="sha256:abc",
|
||||
total_size=1000,
|
||||
layers={"layer1": 500, "layer2": 500},
|
||||
)
|
||||
assert manifest.layer_count == 2
|
||||
|
||||
|
||||
async def test_get_manifest_success(coresys: CoreSys, websession: MagicMock):
|
||||
"""Test successful manifest fetch by mocking internal methods."""
|
||||
fetcher = RegistryManifestFetcher(coresys)
|
||||
manifest_data = {
|
||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": {"digest": "sha256:abc123"},
|
||||
"layers": [
|
||||
{"digest": "sha256:layer1abc123def456789012", "size": 1000},
|
||||
{"digest": "sha256:layer2def456abc789012345", "size": 2000},
|
||||
],
|
||||
}
|
||||
|
||||
# Mock the internal methods
|
||||
with (
|
||||
patch.object(
|
||||
fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
|
||||
),
|
||||
patch.object(
|
||||
fetcher, "_fetch_manifest", new=AsyncMock(return_value=manifest_data)
|
||||
),
|
||||
):
|
||||
result = await fetcher.get_manifest(
|
||||
"test.io/org/image", "v1.0", platform="linux/amd64"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.total_size == 3000
|
||||
assert result.layer_count == 2
|
||||
# First 12 chars after sha256:
|
||||
assert "layer1abc123" in result.layers
|
||||
assert result.layers["layer1abc123"] == 1000
|
||||
|
||||
|
||||
async def test_get_manifest_returns_none_on_failure(
|
||||
coresys: CoreSys, websession: MagicMock
|
||||
):
|
||||
"""Test that get_manifest returns None on failure."""
|
||||
fetcher = RegistryManifestFetcher(coresys)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
|
||||
),
|
||||
patch.object(fetcher, "_fetch_manifest", new=AsyncMock(return_value=None)),
|
||||
):
|
||||
result = await fetcher.get_manifest(
|
||||
"test.io/org/image", "v1.0", platform="linux/amd64"
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_credentials_docker_hub(coresys: CoreSys, websession: MagicMock):
|
||||
"""Test getting Docker Hub credentials."""
|
||||
coresys.docker.config._data["registries"] = { # pylint: disable=protected-access
|
||||
"docker.io": {"username": "user", "password": "pass"}
|
||||
}
|
||||
fetcher = RegistryManifestFetcher(coresys)
|
||||
|
||||
creds = fetcher._get_credentials(DOCKER_HUB) # pylint: disable=protected-access
|
||||
|
||||
assert creds == ("user", "pass")
|
||||
|
||||
|
||||
def test_get_credentials_custom_registry(coresys: CoreSys, websession: MagicMock):
|
||||
"""Test getting credentials for custom registry."""
|
||||
coresys.docker.config._data["registries"] = { # pylint: disable=protected-access
|
||||
"ghcr.io": {"username": "user", "password": "token"}
|
||||
}
|
||||
fetcher = RegistryManifestFetcher(coresys)
|
||||
|
||||
creds = fetcher._get_credentials("ghcr.io") # pylint: disable=protected-access
|
||||
|
||||
assert creds == ("user", "token")
|
||||
|
||||
|
||||
def test_get_credentials_not_found(coresys: CoreSys, websession: MagicMock):
|
||||
"""Test no credentials found."""
|
||||
coresys.docker.config._data["registries"] = {} # pylint: disable=protected-access
|
||||
fetcher = RegistryManifestFetcher(coresys)
|
||||
|
||||
creds = fetcher._get_credentials("unknown.io") # pylint: disable=protected-access
|
||||
|
||||
assert creds is None
|
||||
1002
tests/docker/test_pull_progress.py
Normal file
1002
tests/docker/test_pull_progress.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user