mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Use service helper to extract swiss public transport config entry (#162810)
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import ATTR_CONFIG_ENTRY_ID
|
||||
from homeassistant.core import (
|
||||
HomeAssistant,
|
||||
@@ -11,7 +10,8 @@ from homeassistant.core import (
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import service
|
||||
from homeassistant.helpers.selector import (
|
||||
NumberSelector,
|
||||
NumberSelectorConfig,
|
||||
@@ -40,30 +40,13 @@ SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema(
|
||||
)
|
||||
|
||||
|
||||
def _async_get_entry(
|
||||
hass: HomeAssistant, config_entry_id: str
|
||||
) -> SwissPublicTransportConfigEntry:
|
||||
"""Get the Swiss public transport config entry."""
|
||||
if not (entry := hass.config_entries.async_get_entry(config_entry_id)):
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="config_entry_not_found",
|
||||
translation_placeholders={"target": config_entry_id},
|
||||
)
|
||||
if entry.state is not ConfigEntryState.LOADED:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="not_loaded",
|
||||
translation_placeholders={"target": entry.title},
|
||||
)
|
||||
return entry
|
||||
|
||||
|
||||
async def _async_fetch_connections(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Fetch a set of connections."""
|
||||
config_entry = _async_get_entry(call.hass, call.data[ATTR_CONFIG_ENTRY_ID])
|
||||
config_entry: SwissPublicTransportConfigEntry = service.async_get_config_entry(
|
||||
call.hass, DOMAIN, call.data[ATTR_CONFIG_ENTRY_ID]
|
||||
)
|
||||
|
||||
limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT
|
||||
try:
|
||||
|
||||
@@ -85,15 +85,9 @@
|
||||
"cannot_connect": {
|
||||
"message": "Cannot connect to server.\n{error}"
|
||||
},
|
||||
"config_entry_not_found": {
|
||||
"message": "Swiss public transport integration instance \"{target}\" not found."
|
||||
},
|
||||
"invalid_data": {
|
||||
"message": "Setup failed for entry {config_title} with invalid data, check at the [stationboard]({stationboard_url}) if your station names are valid.\n{error}"
|
||||
},
|
||||
"not_loaded": {
|
||||
"message": "{target} is not loaded."
|
||||
},
|
||||
"request_timeout": {
|
||||
"message": "Timeout while connecting for entry {config_title}.\n{error}"
|
||||
}
|
||||
|
||||
@@ -201,7 +201,7 @@ async def test_service_call_load_unload(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(
|
||||
ServiceValidationError, match=f"{config_entry.title} is not loaded"
|
||||
ServiceValidationError, match="service_config_entry_not_loaded"
|
||||
):
|
||||
await hass.services.async_call(
|
||||
domain=DOMAIN,
|
||||
@@ -215,7 +215,7 @@ async def test_service_call_load_unload(
|
||||
|
||||
with pytest.raises(
|
||||
ServiceValidationError,
|
||||
match=f'Swiss public transport integration instance "{bad_entry_id}" not found',
|
||||
match="service_config_entry_not_found",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
domain=DOMAIN,
|
||||
|
||||
Reference in New Issue
Block a user