Skip to content

Commit

Permalink
Change Trafikverket Train to use station signatures (home-assistant#1…
Browse files Browse the repository at this point in the history
…31416)

Co-authored-by: Robert Resch <robert@resch.dev>
  • Loading branch information
gjohansson-ST and edenhaus authored Jan 13, 2025
1 parent 1575486 commit 4709a31
Show file tree
Hide file tree
Showing 7 changed files with 511 additions and 215 deletions.
62 changes: 57 additions & 5 deletions homeassistant/components/trafikverket_train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,21 @@

import logging

from pytrafikverket import (
InvalidAuthentication,
NoTrainStationFound,
TrafikverketTrain,
UnknownError,
)

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import PLATFORMS
from .const import CONF_FROM, CONF_TO, PLATFORMS
from .coordinator import TVDataUpdateCoordinator

TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator]
Expand Down Expand Up @@ -52,13 +62,55 @@ async def async_migrate_entry(hass: HomeAssistant, entry: TVTrainConfigEntry) ->
"""Migrate config entry."""
_LOGGER.debug("Migrating from version %s", entry.version)

if entry.version > 1:
if entry.version > 2:
# This means the user has downgraded from a future version
return False

if entry.version == 1 and entry.minor_version == 1:
# Remove unique id
hass.config_entries.async_update_entry(entry, unique_id=None, minor_version=2)
if entry.version == 1:
if entry.minor_version == 1:
# Remove unique id
hass.config_entries.async_update_entry(
entry, unique_id=None, minor_version=2
)

# Change from station names to station signatures
try:
web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, entry.data[CONF_API_KEY])
from_stations = await train_api.async_search_train_stations(
entry.data[CONF_FROM]
)
to_stations = await train_api.async_search_train_stations(
entry.data[CONF_TO]
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
except NoTrainStationFound as error:
_LOGGER.error(
"Migration failed as no train station found with provided name %s",
str(error),
)
return False
except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error))
return False
except Exception as error: # noqa: BLE001
_LOGGER.error("Unknown exception occurred during validation %s", str(error))
return False

if len(from_stations) > 1 or len(to_stations) > 1:
_LOGGER.error(
"Migration failed as more than one station found with provided name"
)
return False

new_data = entry.data.copy()
new_data[CONF_FROM] = from_stations[0].signature
new_data[CONF_TO] = to_stations[0].signature

hass.config_entries.async_update_entry(
entry, data=new_data, version=2, minor_version=1
)

_LOGGER.debug(
"Migration to version %s.%s successful",
Expand Down
201 changes: 123 additions & 78 deletions homeassistant/components/trafikverket_train/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from __future__ import annotations

from collections.abc import Mapping
from datetime import datetime
import logging
from typing import Any

from pytrafikverket import TrafikverketTrain
from pytrafikverket.exceptions import (
from pytrafikverket import (
InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound,
StationInfoModel,
TrafikverketTrain,
UnknownError,
)
import voluptuous as vol
Expand All @@ -28,16 +26,15 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TextSelector,
TimeSelector,
)
import homeassistant.util.dt as dt_util

from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import next_departuredate

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,64 +65,42 @@
)


async def validate_input(
async def validate_station(
hass: HomeAssistant,
api_key: str,
train_from: str,
train_to: str,
train_time: str | None,
weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]:
train_station: str,
field: str,
) -> tuple[list[StationInfoModel], dict[str, str]]:
"""Validate input from user input."""
errors: dict[str, str] = {}

when = dt_util.now()
if train_time:
departure_day = next_departuredate(weekdays)
if _time := dt_util.parse_time(train_time):
when = datetime.combine(
departure_day,
_time,
dt_util.get_default_time_zone(),
)

stations = []
try:
web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, api_key)
from_station = await train_api.async_search_train_station(train_from)
to_station = await train_api.async_search_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else:
await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
stations = await train_api.async_search_train_stations(train_station)
except InvalidAuthentication:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
errors["base"] = "invalid_station"
except MultipleTrainStationsFound:
errors["base"] = "more_stations"
except NoTrainAnnouncementFound:
errors["base"] = "no_trains"
errors[field] = "invalid_station"
except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error))
errors["base"] = "cannot_connect"
except Exception as error: # noqa: BLE001
_LOGGER.error("Unknown exception occurred during validation %s", str(error))
errors["base"] = "cannot_connect"

return errors
return (stations, errors)


class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Trafikverket Train integration."""

VERSION = 1
MINOR_VERSION = 2
VERSION = 2
MINOR_VERSION = 1

_from_stations: list[StationInfoModel]
_to_stations: list[StationInfoModel]
_data: dict[str, Any]

@staticmethod
@callback
Expand All @@ -151,14 +126,11 @@ async def async_step_reauth_confirm(
api_key = user_input[CONF_API_KEY]

reauth_entry = self._get_reauth_entry()
errors = await validate_input(
_, errors = await validate_station(
self.hass,
api_key,
reauth_entry.data[CONF_FROM],
reauth_entry.data[CONF_TO],
reauth_entry.data.get(CONF_TIME),
reauth_entry.data[CONF_WEEKDAY],
reauth_entry.options.get(CONF_FILTER_PRODUCT),
CONF_FROM,
)
if not errors:
return self.async_update_reload_and_abort(
Expand Down Expand Up @@ -193,38 +165,40 @@ async def async_step_user(
if train_time:
name = f"{train_from} to {train_to} at {train_time}"

errors = await validate_input(
self.hass,
api_key,
train_from,
train_to,
train_time,
train_days,
filter_product,
self._from_stations, from_errors = await validate_station(
self.hass, api_key, train_from, CONF_FROM
)
self._to_stations, to_errors = await validate_station(
self.hass, api_key, train_to, CONF_TO
)
errors = {**from_errors, **to_errors}

if not errors:
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
if len(self._from_stations) == 1 and len(self._to_stations) == 1:
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: self._from_stations[0].signature,
CONF_TO: self._to_stations[0].signature,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: self._from_stations[0].signature,
CONF_TO: self._to_stations[0].signature,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
self._data = user_input
return await self.async_step_select_stations()

return self.async_show_form(
step_id="user",
Expand All @@ -234,6 +208,77 @@ async def async_step_user(
errors=errors,
)

async def async_step_select_stations(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the select station step."""
if user_input is not None:
api_key: str = self._data[CONF_API_KEY]
train_from: str = user_input[CONF_FROM]
train_to: str = user_input[CONF_TO]
train_time: str | None = self._data.get(CONF_TIME)
train_days: list = self._data[CONF_WEEKDAY]
filter_product: str | None = self._data[CONF_FILTER_PRODUCT]

if filter_product == "":
filter_product = None

name = f"{self._data[CONF_FROM]} to {self._data[CONF_TO]}"
if train_time:
name = (
f"{self._data[CONF_FROM]} to {self._data[CONF_TO]} at {train_time}"
)
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: train_from,
CONF_TO: user_input[CONF_TO],
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
from_options = [
SelectOptionDict(value=station.signature, label=station.station_name)
for station in self._from_stations
]
to_options = [
SelectOptionDict(value=station.signature, label=station.station_name)
for station in self._to_stations
]
schema = {}
if len(from_options) > 1:
schema[vol.Required(CONF_FROM)] = SelectSelector(
SelectSelectorConfig(
options=from_options, mode=SelectSelectorMode.DROPDOWN, sort=True
)
)
if len(to_options) > 1:
schema[vol.Required(CONF_TO)] = SelectSelector(
SelectSelectorConfig(
options=to_options, mode=SelectSelectorMode.DROPDOWN, sort=True
)
)

return self.async_show_form(
step_id="select_stations",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(schema), user_input or {}
),
)


class TVTrainOptionsFlowHandler(OptionsFlow):
"""Handle Trafikverket Train options."""
Expand Down
Loading

0 comments on commit 4709a31

Please sign in to comment.