Skip to content

Commit

Permalink
Add time and offset config to Swiss public transport connections (hom…
Browse files Browse the repository at this point in the history
…e-assistant#120357)

* add time and offset config for connections

* split the config flow

* fix arrival config

* add time_mode data description

* use delta as dict instead of string

* simplify the config_flow

* improve descriptions of config_flow

* improve config flow

* remove obsolete string

* switch priority of the config options

* improvements
  • Loading branch information
miaucl authored Nov 27, 2024
1 parent 345c1fe commit 284fe17
Show file tree
Hide file tree
Showing 8 changed files with 437 additions and 74 deletions.
37 changes: 29 additions & 8 deletions homeassistant/components/swiss_public_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,22 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import ConfigType

from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS
from .const import (
CONF_DESTINATION,
CONF_START,
CONF_TIME_FIXED,
CONF_TIME_OFFSET,
CONF_TIME_STATION,
CONF_VIA,
DEFAULT_TIME_STATION,
DOMAIN,
PLACEHOLDERS,
)
from .coordinator import (
SwissPublicTransportConfigEntry,
SwissPublicTransportDataUpdateCoordinator,
)
from .helper import unique_id_from_config
from .helper import offset_opendata, unique_id_from_config
from .services import setup_services

_LOGGER = logging.getLogger(__name__)
Expand All @@ -50,8 +60,19 @@ async def async_setup_entry(
start = config[CONF_START]
destination = config[CONF_DESTINATION]

time_offset: dict[str, int] | None = config.get(CONF_TIME_OFFSET)

session = async_get_clientsession(hass)
opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA))
opendata = OpendataTransport(
start,
destination,
session,
via=config.get(CONF_VIA),
time=config.get(CONF_TIME_FIXED),
isArrivalTime=config.get(CONF_TIME_STATION, DEFAULT_TIME_STATION) == "arrival",
)
if time_offset:
offset_opendata(opendata, time_offset)

try:
await opendata.async_get_data()
Expand All @@ -75,7 +96,7 @@ async def async_setup_entry(
},
) from e

coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata)
coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata, time_offset)
await coordinator.async_config_entry_first_refresh()
entry.runtime_data = coordinator

Expand All @@ -96,7 +117,7 @@ async def async_migrate_entry(
"""Migrate config entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)

if config_entry.version > 2:
if config_entry.version > 3:
# This means the user has downgraded from a future version
return False

Expand Down Expand Up @@ -131,9 +152,9 @@ async def async_migrate_entry(
config_entry, unique_id=new_unique_id, minor_version=2
)

if config_entry.version < 2:
# Via stations now available, which are not backwards compatible if used, changes unique id
hass.config_entries.async_update_entry(config_entry, version=2, minor_version=1)
if config_entry.version < 3:
# Via stations and time/offset settings now available, which are not backwards compatible if used, changes unique id
hass.config_entries.async_update_entry(config_entry, version=3, minor_version=1)

_LOGGER.debug(
"Migration to version %s.%s successful",
Expand Down
159 changes: 134 additions & 25 deletions homeassistant/components/swiss_public_transport/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,35 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
DurationSelector,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TextSelector,
TextSelectorConfig,
TextSelectorType,
TimeSelector,
)

from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS
from .helper import unique_id_from_config
from .const import (
CONF_DESTINATION,
CONF_START,
CONF_TIME_FIXED,
CONF_TIME_MODE,
CONF_TIME_OFFSET,
CONF_TIME_STATION,
CONF_VIA,
DEFAULT_TIME_MODE,
DEFAULT_TIME_STATION,
DOMAIN,
IS_ARRIVAL_OPTIONS,
MAX_VIA,
PLACEHOLDERS,
TIME_MODE_OPTIONS,
)
from .helper import offset_opendata, unique_id_from_config

DATA_SCHEMA = vol.Schema(
USER_DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_START): cv.string,
vol.Optional(CONF_VIA): TextSelector(
Expand All @@ -32,56 +52,145 @@
),
),
vol.Required(CONF_DESTINATION): cv.string,
vol.Optional(CONF_TIME_MODE, default=DEFAULT_TIME_MODE): SelectSelector(
SelectSelectorConfig(
options=TIME_MODE_OPTIONS,
mode=SelectSelectorMode.DROPDOWN,
translation_key="time_mode",
),
),
vol.Optional(CONF_TIME_STATION, default=DEFAULT_TIME_STATION): SelectSelector(
SelectSelectorConfig(
options=IS_ARRIVAL_OPTIONS,
mode=SelectSelectorMode.DROPDOWN,
translation_key="time_station",
),
),
}
)
ADVANCED_TIME_DATA_SCHEMA = {vol.Optional(CONF_TIME_FIXED): TimeSelector()}
ADVANCED_TIME_OFFSET_DATA_SCHEMA = {vol.Optional(CONF_TIME_OFFSET): DurationSelector()}


_LOGGER = logging.getLogger(__name__)


class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN):
"""Swiss public transport config flow."""

VERSION = 2
VERSION = 3
MINOR_VERSION = 1

user_input: dict[str, Any]

async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Async user step to set up the connection."""
errors: dict[str, str] = {}
if user_input is not None:
unique_id = unique_id_from_config(user_input)
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA:
errors["base"] = "too_many_via_stations"
else:
session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START],
user_input[CONF_DESTINATION],
session,
via=user_input.get(CONF_VIA),
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
err = await self.fetch_connections(user_input)
if err:
errors["base"] = err
else:
self.user_input = user_input
if user_input[CONF_TIME_MODE] == "fixed":
return await self.async_step_time_fixed()
if user_input[CONF_TIME_MODE] == "offset":
return await self.async_step_time_offset()

unique_id = unique_id_from_config(user_input)
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()
return self.async_create_entry(
title=unique_id,
data=user_input,
)

return self.async_show_form(
step_id="user",
data_schema=DATA_SCHEMA,
data_schema=self.add_suggested_values_to_schema(
data_schema=USER_DATA_SCHEMA,
suggested_values=user_input,
),
errors=errors,
description_placeholders=PLACEHOLDERS,
)

async def async_step_time_fixed(
self, time_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Async time step to set up the connection."""
return await self._async_step_time_mode(
CONF_TIME_FIXED, vol.Schema(ADVANCED_TIME_DATA_SCHEMA), time_input
)

async def async_step_time_offset(
self, time_offset_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Async time offset step to set up the connection."""
return await self._async_step_time_mode(
CONF_TIME_OFFSET,
vol.Schema(ADVANCED_TIME_OFFSET_DATA_SCHEMA),
time_offset_input,
)

async def _async_step_time_mode(
self,
step_id: str,
time_mode_schema: vol.Schema,
time_mode_input: dict[str, Any] | None = None,
) -> ConfigFlowResult:
"""Async time mode step to set up the connection."""
errors: dict[str, str] = {}
if time_mode_input is not None:
unique_id = unique_id_from_config({**self.user_input, **time_mode_input})
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()

err = await self.fetch_connections(
{**self.user_input, **time_mode_input},
time_mode_input.get(CONF_TIME_OFFSET),
)
if err:
errors["base"] = err
else:
return self.async_create_entry(
title=unique_id,
data={**self.user_input, **time_mode_input},
)

return self.async_show_form(
step_id=step_id,
data_schema=time_mode_schema,
errors=errors,
description_placeholders=PLACEHOLDERS,
)

async def fetch_connections(
self, input: dict[str, Any], time_offset: dict[str, int] | None = None
) -> str | None:
"""Fetch the connections and advancedly return an error."""
try:
session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
input[CONF_START],
input[CONF_DESTINATION],
session,
via=input.get(CONF_VIA),
time=input.get(CONF_TIME_FIXED),
)
if time_offset:
offset_opendata(opendata, time_offset)
await opendata.async_get_data()
except OpendataTransportConnectionError:
return "cannot_connect"
except OpendataTransportError:
return "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
return "unknown"
return None
8 changes: 8 additions & 0 deletions homeassistant/components/swiss_public_transport/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
CONF_DESTINATION: Final = "to"
CONF_START: Final = "from"
CONF_VIA: Final = "via"
CONF_TIME_STATION: Final = "time_station"
CONF_TIME_MODE: Final = "time_mode"
CONF_TIME_FIXED: Final = "time_fixed"
CONF_TIME_OFFSET: Final = "time_offset"

DEFAULT_NAME = "Next Destination"
DEFAULT_UPDATE_TIME = 90
DEFAULT_TIME_STATION = "departure"
DEFAULT_TIME_MODE = "now"

MAX_VIA = 5
CONNECTIONS_COUNT = 3
CONNECTIONS_MAX = 15
IS_ARRIVAL_OPTIONS = ["departure", "arrival"]
TIME_MODE_OPTIONS = ["now", "fixed", "offset"]


PLACEHOLDERS = {
Expand Down
12 changes: 11 additions & 1 deletion homeassistant/components/swiss_public_transport/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from homeassistant.util.json import JsonValueType

from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN
from .helper import offset_opendata

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,7 +58,12 @@ class SwissPublicTransportDataUpdateCoordinator(

config_entry: SwissPublicTransportConfigEntry

def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
def __init__(
self,
hass: HomeAssistant,
opendata: OpendataTransport,
time_offset: dict[str, int] | None,
) -> None:
"""Initialize the SwissPublicTransport data coordinator."""
super().__init__(
hass,
Expand All @@ -66,6 +72,7 @@ def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
update_interval=timedelta(seconds=DEFAULT_UPDATE_TIME),
)
self._opendata = opendata
self._time_offset = time_offset

def remaining_time(self, departure) -> timedelta | None:
"""Calculate the remaining time for the departure."""
Expand All @@ -81,6 +88,9 @@ async def _async_update_data(self) -> list[DataConnection]:
async def fetch_connections(self, limit: int) -> list[DataConnection]:
"""Fetch connections using the opendata api."""
self._opendata.limit = limit
if self._time_offset:
offset_opendata(self._opendata, self._time_offset)

try:
await self._opendata.async_get_data()
except OpendataTransportConnectionError as e:
Expand Down
Loading

0 comments on commit 284fe17

Please sign in to comment.