Skip to content

Commit

Permalink
add time and offset config for connections
Browse files Browse the repository at this point in the history
  • Loading branch information
miaucl committed Jun 24, 2024
1 parent 59dd63e commit 3990743
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 21 deletions.
43 changes: 35 additions & 8 deletions homeassistant/components/swiss_public_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,21 @@
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession

from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS
from .const import (
CONF_DESTINATION,
CONF_START,
CONF_TIME,
CONF_TIME_OFFSET,
CONF_VIA,
DOMAIN,
PLACEHOLDERS,
)
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .helper import unique_id_from_config
from .helper import (
dict_duration_to_str_duration,
offset_opendata,
unique_id_from_config,
)

_LOGGER = logging.getLogger(__name__)

Expand All @@ -33,8 +45,23 @@ async def async_setup_entry(
start = config[CONF_START]
destination = config[CONF_DESTINATION]

time_offset_dict: dict[str, int] | None = config.get(CONF_TIME_OFFSET)
time_offset = (
dict_duration_to_str_duration(time_offset_dict)
if CONF_TIME_OFFSET in config and time_offset_dict is not None
else None
)

session = async_get_clientsession(hass)
opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA))
opendata = OpendataTransport(
start,
destination,
session,
via=config.get(CONF_VIA),
time=config.get(CONF_TIME),
)
if time_offset:
offset_opendata(opendata, time_offset)

try:
await opendata.async_get_data()
Expand All @@ -58,7 +85,7 @@ async def async_setup_entry(
},
) from e

coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata)
coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata, time_offset)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator

Expand All @@ -82,7 +109,7 @@ async def async_migrate_entry(
"""Migrate config entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)

if config_entry.version > 2:
if config_entry.version > 3:
# This means the user has downgraded from a future version
return False

Expand Down Expand Up @@ -117,9 +144,9 @@ async def async_migrate_entry(
config_entry, unique_id=new_unique_id, minor_version=2
)

if config_entry.version < 2:
# Via stations now available, which are not backwards compatible if used, changes unique id
hass.config_entries.async_update_entry(config_entry, version=2, minor_version=1)
if config_entry.version < 3:
# Via stations and time/offset settings now available, which are not backwards compatible if used, changes unique id
hass.config_entries.async_update_entry(config_entry, version=3, minor_version=1)

_LOGGER.debug(
"Migration to version %s.%s successful",
Expand Down
38 changes: 35 additions & 3 deletions homeassistant/components/swiss_public_transport/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,29 @@
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
DurationSelector,
TextSelector,
TextSelectorConfig,
TextSelectorType,
TimeSelector,
)

from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS
from .helper import unique_id_from_config
from .const import (
CONF_DESTINATION,
CONF_IS_ARRIVAL,
CONF_START,
CONF_TIME,
CONF_TIME_OFFSET,
CONF_VIA,
DOMAIN,
MAX_VIA,
PLACEHOLDERS,
)
from .helper import (
dict_duration_to_str_duration,
offset_opendata,
unique_id_from_config,
)

DATA_SCHEMA = vol.Schema(
{
Expand All @@ -32,6 +48,9 @@
),
),
vol.Required(CONF_DESTINATION): cv.string,
vol.Optional(CONF_TIME): TimeSelector(),
vol.Optional(CONF_TIME_OFFSET): DurationSelector(),
vol.Optional(CONF_IS_ARRIVAL): bool,
}
)

Expand All @@ -41,7 +60,7 @@
class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN):
"""Swiss public transport config flow."""

VERSION = 2
VERSION = 3
MINOR_VERSION = 1

async def async_step_user(
Expand All @@ -56,14 +75,27 @@ async def async_step_user(

if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA:
errors["base"] = "too_many_via_stations"
elif CONF_TIME in user_input and CONF_TIME_OFFSET in user_input:
errors["base"] = "mutex_time_offset"
else:
session = async_get_clientsession(self.hass)
time_offset_dict: dict[str, int] | None = user_input.get(
CONF_TIME_OFFSET
)
time_offset = (
dict_duration_to_str_duration(time_offset_dict)
if CONF_TIME_OFFSET in user_input and time_offset_dict is not None
else None
)
opendata = OpendataTransport(
user_input[CONF_START],
user_input[CONF_DESTINATION],
session,
via=user_input.get(CONF_VIA),
time=user_input.get(CONF_TIME),
)
if time_offset:
offset_opendata(opendata, time_offset)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
Expand Down
4 changes: 4 additions & 0 deletions homeassistant/components/swiss_public_transport/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
CONF_DESTINATION: Final = "to"
CONF_START: Final = "from"
CONF_VIA: Final = "via"
CONF_TIME: Final = "time"
CONF_TIME_OFFSET: Final = "time_offset"
CONF_IS_ARRIVAL: Final = "is_arrival"

DEFAULT_NAME = "Next Destination"
DEFAULT_IS_ARRIVAL = False

MAX_VIA = 5
SENSOR_CONNECTIONS_COUNT = 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import homeassistant.util.dt as dt_util

from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT
from .helper import offset_opendata

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,7 +49,9 @@ class SwissPublicTransportDataUpdateCoordinator(

config_entry: ConfigEntry

def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
def __init__(
self, hass: HomeAssistant, opendata: OpendataTransport, time_offset: str | None
) -> None:
"""Initialize the SwissPublicTransport data coordinator."""
super().__init__(
hass,
Expand All @@ -57,6 +60,7 @@ def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
update_interval=timedelta(seconds=90),
)
self._opendata = opendata
self._time_offset = time_offset

def remaining_time(self, departure) -> timedelta | None:
"""Calculate the remaining time for the departure."""
Expand All @@ -74,6 +78,9 @@ def nth_departure_time(self, i: int) -> datetime | None:
return None

async def _async_update_data(self) -> list[DataConnection]:
if self._time_offset:
offset_opendata(self._opendata, self._time_offset)

try:
await self._opendata.async_get_data()
except OpendataTransportError as e:
Expand Down
49 changes: 44 additions & 5 deletions homeassistant/components/swiss_public_transport/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,52 @@
from types import MappingProxyType
from typing import Any

from .const import CONF_DESTINATION, CONF_START, CONF_VIA
from opendata_transport import OpendataTransport

import homeassistant.util.dt as dt_util

from .const import (
CONF_DESTINATION,
CONF_IS_ARRIVAL,
CONF_START,
CONF_TIME,
CONF_TIME_OFFSET,
CONF_VIA,
DEFAULT_IS_ARRIVAL,
)


def offset_opendata(opendata: OpendataTransport, offset: str) -> None:
"""In place offset the opendata connector."""

duration = dt_util.parse_duration(offset)
if duration:
now_offset = dt_util.as_local(dt_util.now() + duration)
opendata.date = now_offset.date()
opendata.time = now_offset.time()


def dict_duration_to_str_duration(
d: dict[str, int],
) -> str:
"""Build a string from a dict duration."""
return f"{d['hours']:02d}:{d['minutes']:02d}:{d['seconds']:02d}"


def unique_id_from_config(config: MappingProxyType[str, Any] | dict[str, Any]) -> str:
"""Build a unique id from a config entry."""
return f"{config[CONF_START]} {config[CONF_DESTINATION]}" + (
" via " + ", ".join(config[CONF_VIA])
if CONF_VIA in config and len(config[CONF_VIA]) > 0
else ""
return (
f"{config[CONF_START]} {config[CONF_DESTINATION]}"
+ (
" via " + ", ".join(config[CONF_VIA])
if CONF_VIA in config and len(config[CONF_VIA]) > 0
else ""
)
+ (" arrival" if config.get(CONF_IS_ARRIVAL, DEFAULT_IS_ARRIVAL) else "")
+ (" at " + config[CONF_TIME] if CONF_TIME in config else "")
+ (
" in " + dict_duration_to_str_duration(config[CONF_TIME_OFFSET])
if CONF_TIME_OFFSET in config
else ""
)
)
8 changes: 6 additions & 2 deletions homeassistant/components/swiss_public_transport/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"cannot_connect": "Cannot connect to server",
"bad_config": "Request failed due to bad config: Check at [stationboard]({stationboard_url}) if your station names are valid",
"too_many_via_stations": "Too many via stations, only up to 5 via stations are allowed per connection.",
"mutex_time_offset": "Setting a time and offset for a connection are mutually exclusive, please choose only one.",
"unknown": "An unknown error was raised by python-opendata-transport"
},
"abort": {
Expand All @@ -17,9 +18,12 @@
"data": {
"from": "Start station",
"to": "End station",
"via": "List of up to 5 via stations"
"via": "List of up to 5 via stations",
"time": "Select a fixed time of day",
"time_offset": "Select a moving time offset",
"is_arrival": "Use arrival instead of departure for time and offset configuration"
},
"description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\nOptionally, you can also configure connections at a specific time or moving offset for departure or arrival.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"title": "Swiss Public Transport"
}
}
Expand Down
26 changes: 26 additions & 0 deletions tests/components/swiss_public_transport/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from homeassistant.components.swiss_public_transport import config_flow
from homeassistant.components.swiss_public_transport.const import (
CONF_DESTINATION,
CONF_IS_ARRIVAL,
CONF_START,
CONF_TIME,
CONF_TIME_OFFSET,
CONF_VIA,
MAX_VIA,
)
Expand Down Expand Up @@ -43,6 +46,23 @@
CONF_VIA: MOCK_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1),
}

MOCK_DATA_STEP_TIME = {
**MOCK_DATA_STEP,
CONF_TIME: "18:03:00",
}

MOCK_DATA_STEP_TIME_OFFSET_ARRIVAL = {
**MOCK_DATA_STEP,
CONF_TIME_OFFSET: {"hours": 0, "minutes": 10, "seconds": 0},
CONF_IS_ARRIVAL: True,
}

MOCK_DATA_STEP_TIME_OFFSET_MUTEX = {
**MOCK_DATA_STEP,
CONF_TIME: "18:03:00",
CONF_TIME_OFFSET: {"hours": 0, "minutes": 10, "seconds": 0},
}


@pytest.mark.parametrize(
("user_input", "config_title"),
Expand All @@ -53,6 +73,11 @@
MOCK_DATA_STEP_MANY_VIA,
"test_start test_destination via via_station_1, via_station_2, via_station_3",
),
(MOCK_DATA_STEP_TIME, "test_start test_destination at 18:03:00"),
(
MOCK_DATA_STEP_TIME_OFFSET_ARRIVAL,
"test_start test_destination arrival in 00:10:00",
),
],
)
async def test_flow_user_init_data_success(
Expand Down Expand Up @@ -93,6 +118,7 @@ async def test_flow_user_init_data_success(
(OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP),
(OpendataTransportError(), "bad_config", MOCK_DATA_STEP),
(None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS),
(None, "mutex_time_offset", MOCK_DATA_STEP_TIME_OFFSET_MUTEX),
(IndexError(), "unknown", MOCK_DATA_STEP),
],
)
Expand Down
Loading

0 comments on commit 3990743

Please sign in to comment.