mirror of
https://github.com/home-assistant/core.git
synced 2026-02-14 23:28:42 +00:00
Use service helper to extract transmission config entry (#162814)
This commit is contained in:
@@ -3,16 +3,15 @@
|
||||
from enum import StrEnum
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from transmission_rpc import Torrent
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse, callback
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.helpers import config_validation as cv, selector
|
||||
from homeassistant.helpers import config_validation as cv, selector, service
|
||||
|
||||
from .const import (
|
||||
ATTR_DELETE_DATA,
|
||||
@@ -31,7 +30,7 @@ from .const import (
|
||||
SERVICE_START_TORRENT,
|
||||
SERVICE_STOP_TORRENT,
|
||||
)
|
||||
from .coordinator import TransmissionDataUpdateCoordinator
|
||||
from .coordinator import TransmissionConfigEntry, TransmissionDataUpdateCoordinator
|
||||
from .helpers import filter_torrents, format_torrents
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -101,41 +100,31 @@ def _get_coordinator_from_service_data(
|
||||
call: ServiceCall,
|
||||
) -> TransmissionDataUpdateCoordinator:
|
||||
"""Return coordinator for entry id."""
|
||||
config_entry_id: str = call.data[CONF_ENTRY_ID]
|
||||
if not (entry := call.hass.config_entries.async_get_entry(config_entry_id)):
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="integration_not_found",
|
||||
translation_placeholders={"target": DOMAIN},
|
||||
)
|
||||
if entry.state is not ConfigEntryState.LOADED:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="not_loaded",
|
||||
translation_placeholders={"target": entry.title},
|
||||
)
|
||||
return cast(TransmissionDataUpdateCoordinator, entry.runtime_data)
|
||||
entry: TransmissionConfigEntry = service.async_get_config_entry(
|
||||
call.hass, DOMAIN, call.data[CONF_ENTRY_ID]
|
||||
)
|
||||
return entry.runtime_data
|
||||
|
||||
|
||||
async def _async_add_torrent(service: ServiceCall) -> None:
|
||||
async def _async_add_torrent(call: ServiceCall) -> None:
|
||||
"""Add new torrent to download."""
|
||||
coordinator = _get_coordinator_from_service_data(service)
|
||||
torrent: str = service.data[ATTR_TORRENT]
|
||||
download_path: str | None = service.data.get(ATTR_DOWNLOAD_PATH)
|
||||
coordinator = _get_coordinator_from_service_data(call)
|
||||
torrent: str = call.data[ATTR_TORRENT]
|
||||
download_path: str | None = call.data.get(ATTR_DOWNLOAD_PATH)
|
||||
labels: list[str] | None = (
|
||||
service.data[ATTR_LABELS].split(",") if ATTR_LABELS in service.data else None
|
||||
call.data[ATTR_LABELS].split(",") if ATTR_LABELS in call.data else None
|
||||
)
|
||||
|
||||
if not (
|
||||
torrent.startswith(("http", "ftp:", "magnet:"))
|
||||
or service.hass.config.is_allowed_path(torrent)
|
||||
or call.hass.config.is_allowed_path(torrent)
|
||||
):
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="could_not_add_torrent",
|
||||
)
|
||||
|
||||
await service.hass.async_add_executor_job(
|
||||
await call.hass.async_add_executor_job(
|
||||
partial(
|
||||
coordinator.api.add_torrent,
|
||||
torrent,
|
||||
@@ -146,17 +135,17 @@ async def _async_add_torrent(service: ServiceCall) -> None:
|
||||
await coordinator.async_request_refresh()
|
||||
|
||||
|
||||
async def _async_get_torrents(service: ServiceCall) -> dict[str, Any] | None:
|
||||
async def _async_get_torrents(call: ServiceCall) -> dict[str, Any] | None:
|
||||
"""Get torrents."""
|
||||
coordinator = _get_coordinator_from_service_data(service)
|
||||
torrent_filter: str = service.data[ATTR_TORRENT_FILTER]
|
||||
coordinator = _get_coordinator_from_service_data(call)
|
||||
torrent_filter: str = call.data[ATTR_TORRENT_FILTER]
|
||||
|
||||
def get_filtered_torrents() -> list[Torrent]:
|
||||
"""Filter torrents based on the filter provided."""
|
||||
all_torrents = coordinator.api.get_torrents()
|
||||
return filter_torrents(all_torrents, FILTER_MODES[torrent_filter])
|
||||
|
||||
torrents = await service.hass.async_add_executor_job(get_filtered_torrents)
|
||||
torrents = await call.hass.async_add_executor_job(get_filtered_torrents)
|
||||
|
||||
info = format_torrents(torrents)
|
||||
return {
|
||||
@@ -164,28 +153,28 @@ async def _async_get_torrents(service: ServiceCall) -> dict[str, Any] | None:
|
||||
}
|
||||
|
||||
|
||||
async def _async_start_torrent(service: ServiceCall) -> None:
|
||||
async def _async_start_torrent(call: ServiceCall) -> None:
|
||||
"""Start torrent."""
|
||||
coordinator = _get_coordinator_from_service_data(service)
|
||||
torrent_id = service.data[CONF_ID]
|
||||
await service.hass.async_add_executor_job(coordinator.api.start_torrent, torrent_id)
|
||||
coordinator = _get_coordinator_from_service_data(call)
|
||||
torrent_id = call.data[CONF_ID]
|
||||
await call.hass.async_add_executor_job(coordinator.api.start_torrent, torrent_id)
|
||||
await coordinator.async_request_refresh()
|
||||
|
||||
|
||||
async def _async_stop_torrent(service: ServiceCall) -> None:
|
||||
async def _async_stop_torrent(call: ServiceCall) -> None:
|
||||
"""Stop torrent."""
|
||||
coordinator = _get_coordinator_from_service_data(service)
|
||||
torrent_id = service.data[CONF_ID]
|
||||
await service.hass.async_add_executor_job(coordinator.api.stop_torrent, torrent_id)
|
||||
coordinator = _get_coordinator_from_service_data(call)
|
||||
torrent_id = call.data[CONF_ID]
|
||||
await call.hass.async_add_executor_job(coordinator.api.stop_torrent, torrent_id)
|
||||
await coordinator.async_request_refresh()
|
||||
|
||||
|
||||
async def _async_remove_torrent(service: ServiceCall) -> None:
|
||||
async def _async_remove_torrent(call: ServiceCall) -> None:
|
||||
"""Remove torrent."""
|
||||
coordinator = _get_coordinator_from_service_data(service)
|
||||
torrent_id = service.data[CONF_ID]
|
||||
delete_data = service.data[ATTR_DELETE_DATA]
|
||||
await service.hass.async_add_executor_job(
|
||||
coordinator = _get_coordinator_from_service_data(call)
|
||||
torrent_id = call.data[CONF_ID]
|
||||
delete_data = call.data[ATTR_DELETE_DATA]
|
||||
await call.hass.async_add_executor_job(
|
||||
partial(coordinator.api.remove_torrent, torrent_id, delete_data=delete_data)
|
||||
)
|
||||
await coordinator.async_request_refresh()
|
||||
|
||||
@@ -90,12 +90,6 @@
|
||||
"exceptions": {
|
||||
"could_not_add_torrent": {
|
||||
"message": "Could not add torrent: unsupported type or no permission."
|
||||
},
|
||||
"integration_not_found": {
|
||||
"message": "Integration \"{target}\" not found in registry."
|
||||
},
|
||||
"not_loaded": {
|
||||
"message": "{target} is not loaded."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
|
||||
@@ -59,9 +59,7 @@ async def test_service_integration_not_found(
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(
|
||||
ServiceValidationError, match='Integration "transmission" not found'
|
||||
):
|
||||
with pytest.raises(ServiceValidationError, match="service_config_entry_not_found"):
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_ADD_TORRENT,
|
||||
|
||||
Reference in New Issue
Block a user