diff --git a/homeassistant/components/swiss_public_transport/__init__.py b/homeassistant/components/swiss_public_transport/__init__.py index bceac6007a261e..628f6e95c2abd6 100644 --- a/homeassistant/components/swiss_public_transport/__init__.py +++ b/homeassistant/components/swiss_public_transport/__init__.py @@ -19,12 +19,22 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.typing import ConfigType -from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS +from .const import ( + CONF_DESTINATION, + CONF_START, + CONF_TIME_FIXED, + CONF_TIME_OFFSET, + CONF_TIME_STATION, + CONF_VIA, + DEFAULT_TIME_STATION, + DOMAIN, + PLACEHOLDERS, +) from .coordinator import ( SwissPublicTransportConfigEntry, SwissPublicTransportDataUpdateCoordinator, ) -from .helper import unique_id_from_config +from .helper import offset_opendata, unique_id_from_config from .services import setup_services _LOGGER = logging.getLogger(__name__) @@ -50,8 +60,19 @@ async def async_setup_entry( start = config[CONF_START] destination = config[CONF_DESTINATION] + time_offset: dict[str, int] | None = config.get(CONF_TIME_OFFSET) + session = async_get_clientsession(hass) - opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA)) + opendata = OpendataTransport( + start, + destination, + session, + via=config.get(CONF_VIA), + time=config.get(CONF_TIME_FIXED), + isArrivalTime=config.get(CONF_TIME_STATION, DEFAULT_TIME_STATION) == "arrival", + ) + if time_offset: + offset_opendata(opendata, time_offset) try: await opendata.async_get_data() @@ -75,7 +96,7 @@ async def async_setup_entry( }, ) from e - coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata) + coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata, time_offset) await coordinator.async_config_entry_first_refresh() entry.runtime_data = coordinator @@ -96,7 +117,7 @@ async def async_migrate_entry( """Migrate config entry.""" _LOGGER.debug("Migrating from version %s", config_entry.version) - if config_entry.version > 2: + if config_entry.version > 3: # This means the user has downgraded from a future version return False @@ -131,9 +152,9 @@ async def async_migrate_entry( config_entry, unique_id=new_unique_id, minor_version=2 ) - if config_entry.version < 2: - # Via stations now available, which are not backwards compatible if used, changes unique id - hass.config_entries.async_update_entry(config_entry, version=2, minor_version=1) + if config_entry.version < 3: + # Via stations and time/offset settings now available, which are not backwards compatible if used, changes unique id + hass.config_entries.async_update_entry(config_entry, version=3, minor_version=1) _LOGGER.debug( "Migration to version %s.%s successful", diff --git a/homeassistant/components/swiss_public_transport/config_flow.py b/homeassistant/components/swiss_public_transport/config_flow.py index 74c6223f1d99a5..58d674f0c266ec 100644 --- a/homeassistant/components/swiss_public_transport/config_flow.py +++ b/homeassistant/components/swiss_public_transport/config_flow.py @@ -14,15 +14,35 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv from homeassistant.helpers.selector import ( + DurationSelector, + SelectSelector, + SelectSelectorConfig, + SelectSelectorMode, TextSelector, TextSelectorConfig, TextSelectorType, + TimeSelector, ) -from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS -from .helper import unique_id_from_config +from .const import ( + CONF_DESTINATION, + CONF_START, + CONF_TIME_FIXED, + CONF_TIME_MODE, + CONF_TIME_OFFSET, + CONF_TIME_STATION, + CONF_VIA, + DEFAULT_TIME_MODE, + DEFAULT_TIME_STATION, + DOMAIN, + IS_ARRIVAL_OPTIONS, + MAX_VIA, + PLACEHOLDERS, + TIME_MODE_OPTIONS, +) +from .helper import offset_opendata, unique_id_from_config -DATA_SCHEMA = vol.Schema( +USER_DATA_SCHEMA = vol.Schema( { vol.Required(CONF_START): cv.string, vol.Optional(CONF_VIA): TextSelector( @@ -32,8 +52,25 @@ ), ), vol.Required(CONF_DESTINATION): cv.string, + vol.Optional(CONF_TIME_MODE, default=DEFAULT_TIME_MODE): SelectSelector( + SelectSelectorConfig( + options=TIME_MODE_OPTIONS, + mode=SelectSelectorMode.DROPDOWN, + translation_key="time_mode", + ), + ), + vol.Optional(CONF_TIME_STATION, default=DEFAULT_TIME_STATION): SelectSelector( + SelectSelectorConfig( + options=IS_ARRIVAL_OPTIONS, + mode=SelectSelectorMode.DROPDOWN, + translation_key="time_station", + ), + ), } ) +ADVANCED_TIME_DATA_SCHEMA = {vol.Optional(CONF_TIME_FIXED): TimeSelector()} +ADVANCED_TIME_OFFSET_DATA_SCHEMA = {vol.Optional(CONF_TIME_OFFSET): DurationSelector()} + _LOGGER = logging.getLogger(__name__) @@ -41,39 +78,33 @@ class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN): """Swiss public transport config flow.""" - VERSION = 2 + VERSION = 3 MINOR_VERSION = 1 + user_input: dict[str, Any] + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Async user step to set up the connection.""" errors: dict[str, str] = {} if user_input is not None: - unique_id = unique_id_from_config(user_input) - await self.async_set_unique_id(unique_id) - self._abort_if_unique_id_configured() - if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA: errors["base"] = "too_many_via_stations" else: - session = async_get_clientsession(self.hass) - opendata = OpendataTransport( - user_input[CONF_START], - user_input[CONF_DESTINATION], - session, - via=user_input.get(CONF_VIA), - ) - try: - await opendata.async_get_data() - except OpendataTransportConnectionError: - errors["base"] = "cannot_connect" - except OpendataTransportError: - errors["base"] = "bad_config" - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Unknown error") - errors["base"] = "unknown" + err = await self.fetch_connections(user_input) + if err: + errors["base"] = err else: + self.user_input = user_input + if user_input[CONF_TIME_MODE] == "fixed": + return await self.async_step_time_fixed() + if user_input[CONF_TIME_MODE] == "offset": + return await self.async_step_time_offset() + + unique_id = unique_id_from_config(user_input) + await self.async_set_unique_id(unique_id) + self._abort_if_unique_id_configured() return self.async_create_entry( title=unique_id, data=user_input, @@ -81,7 +112,85 @@ async def async_step_user( return self.async_show_form( step_id="user", - data_schema=DATA_SCHEMA, + data_schema=self.add_suggested_values_to_schema( + data_schema=USER_DATA_SCHEMA, + suggested_values=user_input, + ), errors=errors, description_placeholders=PLACEHOLDERS, ) + + async def async_step_time_fixed( + self, time_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Async time step to set up the connection.""" + return await self._async_step_time_mode( + CONF_TIME_FIXED, vol.Schema(ADVANCED_TIME_DATA_SCHEMA), time_input + ) + + async def async_step_time_offset( + self, time_offset_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Async time offset step to set up the connection.""" + return await self._async_step_time_mode( + CONF_TIME_OFFSET, + vol.Schema(ADVANCED_TIME_OFFSET_DATA_SCHEMA), + time_offset_input, + ) + + async def _async_step_time_mode( + self, + step_id: str, + time_mode_schema: vol.Schema, + time_mode_input: dict[str, Any] | None = None, + ) -> ConfigFlowResult: + """Async time mode step to set up the connection.""" + errors: dict[str, str] = {} + if time_mode_input is not None: + unique_id = unique_id_from_config({**self.user_input, **time_mode_input}) + await self.async_set_unique_id(unique_id) + self._abort_if_unique_id_configured() + + err = await self.fetch_connections( + {**self.user_input, **time_mode_input}, + time_mode_input.get(CONF_TIME_OFFSET), + ) + if err: + errors["base"] = err + else: + return self.async_create_entry( + title=unique_id, + data={**self.user_input, **time_mode_input}, + ) + + return self.async_show_form( + step_id=step_id, + data_schema=time_mode_schema, + errors=errors, + description_placeholders=PLACEHOLDERS, + ) + + async def fetch_connections( + self, input: dict[str, Any], time_offset: dict[str, int] | None = None + ) -> str | None: + """Fetch the connections and advancedly return an error.""" + try: + session = async_get_clientsession(self.hass) + opendata = OpendataTransport( + input[CONF_START], + input[CONF_DESTINATION], + session, + via=input.get(CONF_VIA), + time=input.get(CONF_TIME_FIXED), + ) + if time_offset: + offset_opendata(opendata, time_offset) + await opendata.async_get_data() + except OpendataTransportConnectionError: + return "cannot_connect" + except OpendataTransportError: + return "bad_config" + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unknown error") + return "unknown" + return None diff --git a/homeassistant/components/swiss_public_transport/const.py b/homeassistant/components/swiss_public_transport/const.py index c02f36f2f25131..10bfc0d03555bf 100644 --- a/homeassistant/components/swiss_public_transport/const.py +++ b/homeassistant/components/swiss_public_transport/const.py @@ -7,13 +7,21 @@ CONF_DESTINATION: Final = "to" CONF_START: Final = "from" CONF_VIA: Final = "via" +CONF_TIME_STATION: Final = "time_station" +CONF_TIME_MODE: Final = "time_mode" +CONF_TIME_FIXED: Final = "time_fixed" +CONF_TIME_OFFSET: Final = "time_offset" DEFAULT_NAME = "Next Destination" DEFAULT_UPDATE_TIME = 90 +DEFAULT_TIME_STATION = "departure" +DEFAULT_TIME_MODE = "now" MAX_VIA = 5 CONNECTIONS_COUNT = 3 CONNECTIONS_MAX = 15 +IS_ARRIVAL_OPTIONS = ["departure", "arrival"] +TIME_MODE_OPTIONS = ["now", "fixed", "offset"] PLACEHOLDERS = { diff --git a/homeassistant/components/swiss_public_transport/coordinator.py b/homeassistant/components/swiss_public_transport/coordinator.py index e6413e6f772692..59602e7b982e46 100644 --- a/homeassistant/components/swiss_public_transport/coordinator.py +++ b/homeassistant/components/swiss_public_transport/coordinator.py @@ -19,6 +19,7 @@ from homeassistant.util.json import JsonValueType from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN +from .helper import offset_opendata _LOGGER = logging.getLogger(__name__) @@ -57,7 +58,12 @@ class SwissPublicTransportDataUpdateCoordinator( config_entry: SwissPublicTransportConfigEntry - def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None: + def __init__( + self, + hass: HomeAssistant, + opendata: OpendataTransport, + time_offset: dict[str, int] | None, + ) -> None: """Initialize the SwissPublicTransport data coordinator.""" super().__init__( hass, @@ -66,6 +72,7 @@ def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None: update_interval=timedelta(seconds=DEFAULT_UPDATE_TIME), ) self._opendata = opendata + self._time_offset = time_offset def remaining_time(self, departure) -> timedelta | None: """Calculate the remaining time for the departure.""" @@ -81,6 +88,9 @@ async def _async_update_data(self) -> list[DataConnection]: async def fetch_connections(self, limit: int) -> list[DataConnection]: """Fetch connections using the opendata api.""" self._opendata.limit = limit + if self._time_offset: + offset_opendata(self._opendata, self._time_offset) + try: await self._opendata.async_get_data() except OpendataTransportConnectionError as e: diff --git a/homeassistant/components/swiss_public_transport/helper.py b/homeassistant/components/swiss_public_transport/helper.py index af03f7ad193a25..704479b77d6b72 100644 --- a/homeassistant/components/swiss_public_transport/helper.py +++ b/homeassistant/components/swiss_public_transport/helper.py @@ -1,15 +1,59 @@ """Helper functions for swiss_public_transport.""" +from datetime import timedelta from types import MappingProxyType from typing import Any -from .const import CONF_DESTINATION, CONF_START, CONF_VIA +from opendata_transport import OpendataTransport + +import homeassistant.util.dt as dt_util + +from .const import ( + CONF_DESTINATION, + CONF_START, + CONF_TIME_FIXED, + CONF_TIME_OFFSET, + CONF_TIME_STATION, + CONF_VIA, + DEFAULT_TIME_STATION, +) + + +def offset_opendata(opendata: OpendataTransport, offset: dict[str, int]) -> None: + """In place offset the opendata connector.""" + + duration = timedelta(**offset) + if duration: + now_offset = dt_util.as_local(dt_util.now() + duration) + opendata.date = now_offset.date() + opendata.time = now_offset.time() + + +def dict_duration_to_str_duration( + d: dict[str, int], +) -> str: + """Build a string from a dict duration.""" + return f"{d['hours']:02d}:{d['minutes']:02d}:{d['seconds']:02d}" def unique_id_from_config(config: MappingProxyType[str, Any] | dict[str, Any]) -> str: """Build a unique id from a config entry.""" - return f"{config[CONF_START]} {config[CONF_DESTINATION]}" + ( - " via " + ", ".join(config[CONF_VIA]) - if CONF_VIA in config and len(config[CONF_VIA]) > 0 - else "" + return ( + f"{config[CONF_START]} {config[CONF_DESTINATION]}" + + ( + " via " + ", ".join(config[CONF_VIA]) + if CONF_VIA in config and len(config[CONF_VIA]) > 0 + else "" + ) + + ( + " arrival" + if config.get(CONF_TIME_STATION, DEFAULT_TIME_STATION) == "arrival" + else "" + ) + + (" at " + config[CONF_TIME_FIXED] if CONF_TIME_FIXED in config else "") + + ( + " in " + dict_duration_to_str_duration(config[CONF_TIME_OFFSET]) + if CONF_TIME_OFFSET in config + else "" + ) ) diff --git a/homeassistant/components/swiss_public_transport/strings.json b/homeassistant/components/swiss_public_transport/strings.json index b3bfd9aea8ff02..91645b2fee4a4d 100644 --- a/homeassistant/components/swiss_public_transport/strings.json +++ b/homeassistant/components/swiss_public_transport/strings.json @@ -17,10 +17,30 @@ "data": { "from": "Start station", "to": "End station", - "via": "List of up to 5 via stations" + "via": "List of up to 5 via stations", + "time_station": "Select the relevant station", + "time_mode": "Select a time mode" + }, + "data_description": { + "time_station": "Usually the departure time of a connection when it leaves the start station is tracked. Alternatively, track the time when the connection arrives at its end station.", + "time_mode": "Time mode lets you change the departure timing and fix it to a specific time (e.g. 7:12:00 AM every morning) or add a moving offset (e.g. +00:05:00 taking into account the time to walk to the station)." }, "description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.", "title": "Swiss Public Transport" + }, + "time_fixed": { + "data": { + "time_fixed": "Time of day" + }, + "description": "Please select the relevant time for the connection (e.g. 7:12:00 AM every morning).", + "title": "Swiss Public Transport" + }, + "time_offset": { + "data": { + "time_offset": "Time offset" + }, + "description": "Please select the relevant offset to add to the earliest possible connection (e.g. add +00:05:00 offset, taking into account the time to walk to the station)", + "title": "Swiss Public Transport" } } }, @@ -84,5 +104,20 @@ "config_entry_not_found": { "message": "Swiss public transport integration instance \"{target}\" not found." } + }, + "selector": { + "time_station": { + "options": { + "departure": "Show departure time from start station", + "arrival": "Show arrival time at end station" + } + }, + "time_mode": { + "options": { + "now": "Now", + "fixed": "At a fixed time of day", + "offset": "At an offset from now" + } + } } } diff --git a/tests/components/swiss_public_transport/test_config_flow.py b/tests/components/swiss_public_transport/test_config_flow.py index 027336e28a675a..7c17b0d4c30668 100644 --- a/tests/components/swiss_public_transport/test_config_flow.py +++ b/tests/components/swiss_public_transport/test_config_flow.py @@ -12,6 +12,10 @@ from homeassistant.components.swiss_public_transport.const import ( CONF_DESTINATION, CONF_START, + CONF_TIME_FIXED, + CONF_TIME_MODE, + CONF_TIME_OFFSET, + CONF_TIME_STATION, CONF_VIA, MAX_VIA, ) @@ -23,40 +27,86 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry") -MOCK_DATA_STEP = { +MOCK_USER_DATA_STEP = { CONF_START: "test_start", CONF_DESTINATION: "test_destination", + CONF_TIME_STATION: "departure", + CONF_TIME_MODE: "now", } -MOCK_DATA_STEP_ONE_VIA = { - **MOCK_DATA_STEP, +MOCK_USER_DATA_STEP_ONE_VIA = { + **MOCK_USER_DATA_STEP, CONF_VIA: ["via_station"], } -MOCK_DATA_STEP_MANY_VIA = { - **MOCK_DATA_STEP, +MOCK_USER_DATA_STEP_MANY_VIA = { + **MOCK_USER_DATA_STEP, CONF_VIA: ["via_station_1", "via_station_2", "via_station_3"], } -MOCK_DATA_STEP_TOO_MANY_STATIONS = { - **MOCK_DATA_STEP, - CONF_VIA: MOCK_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1), +MOCK_USER_DATA_STEP_TOO_MANY_STATIONS = { + **MOCK_USER_DATA_STEP, + CONF_VIA: MOCK_USER_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1), +} + +MOCK_USER_DATA_STEP_ARRIVAL = { + **MOCK_USER_DATA_STEP, + CONF_TIME_STATION: "arrival", +} + +MOCK_USER_DATA_STEP_TIME_FIXED = { + **MOCK_USER_DATA_STEP, + CONF_TIME_MODE: "fixed", +} + +MOCK_USER_DATA_STEP_TIME_FIXED_OFFSET = { + **MOCK_USER_DATA_STEP, + CONF_TIME_MODE: "offset", +} + +MOCK_USER_DATA_STEP_BAD = { + **MOCK_USER_DATA_STEP, + CONF_TIME_MODE: "bad", +} + +MOCK_ADVANCED_DATA_STEP_TIME = { + CONF_TIME_FIXED: "18:03:00", +} + +MOCK_ADVANCED_DATA_STEP_TIME_OFFSET = { + CONF_TIME_OFFSET: {"hours": 0, "minutes": 10, "seconds": 0}, } @pytest.mark.parametrize( - ("user_input", "config_title"), + ("user_input", "time_mode_input", "config_title"), [ - (MOCK_DATA_STEP, "test_start test_destination"), - (MOCK_DATA_STEP_ONE_VIA, "test_start test_destination via via_station"), + (MOCK_USER_DATA_STEP, None, "test_start test_destination"), ( - MOCK_DATA_STEP_MANY_VIA, + MOCK_USER_DATA_STEP_ONE_VIA, + None, + "test_start test_destination via via_station", + ), + ( + MOCK_USER_DATA_STEP_MANY_VIA, + None, "test_start test_destination via via_station_1, via_station_2, via_station_3", ), + (MOCK_USER_DATA_STEP_ARRIVAL, None, "test_start test_destination arrival"), + ( + MOCK_USER_DATA_STEP_TIME_FIXED, + MOCK_ADVANCED_DATA_STEP_TIME, + "test_start test_destination at 18:03:00", + ), + ( + MOCK_USER_DATA_STEP_TIME_FIXED_OFFSET, + MOCK_ADVANCED_DATA_STEP_TIME_OFFSET, + "test_start test_destination in 00:10:00", + ), ], ) async def test_flow_user_init_data_success( - hass: HomeAssistant, user_input, config_title + hass: HomeAssistant, user_input, time_mode_input, config_title ) -> None: """Test success response.""" result = await hass.config_entries.flow.async_init( @@ -66,48 +116,56 @@ async def test_flow_user_init_data_success( assert result["type"] is FlowResultType.FORM assert result["step_id"] == "user" assert result["handler"] == "swiss_public_transport" - assert result["data_schema"] == config_flow.DATA_SCHEMA + assert result["data_schema"] == config_flow.USER_DATA_SCHEMA with patch( "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", autospec=True, return_value=True, ): - result = await hass.config_entries.flow.async_init( - config_flow.DOMAIN, context={"source": "user"} - ) result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input=user_input, ) + if time_mode_input: + assert result["type"] == FlowResultType.FORM + if CONF_TIME_FIXED in time_mode_input: + assert result["step_id"] == "time_fixed" + if CONF_TIME_OFFSET in time_mode_input: + assert result["step_id"] == "time_offset" + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=time_mode_input, + ) + assert result["type"] == FlowResultType.CREATE_ENTRY assert result["result"].title == config_title - assert result["data"] == user_input + assert result["data"] == {**user_input, **(time_mode_input or {})} @pytest.mark.parametrize( ("raise_error", "text_error", "user_input_error"), [ - (OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP), - (OpendataTransportError(), "bad_config", MOCK_DATA_STEP), - (None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS), - (IndexError(), "unknown", MOCK_DATA_STEP), + (OpendataTransportConnectionError(), "cannot_connect", MOCK_USER_DATA_STEP), + (OpendataTransportError(), "bad_config", MOCK_USER_DATA_STEP), + (None, "too_many_via_stations", MOCK_USER_DATA_STEP_TOO_MANY_STATIONS), + (IndexError(), "unknown", MOCK_USER_DATA_STEP), ], ) -async def test_flow_user_init_data_error_and_recover( +async def test_flow_user_init_data_error_and_recover_on_step_1( hass: HomeAssistant, raise_error, text_error, user_input_error ) -> None: - """Test unknown errors.""" + """Test errors in user step.""" + result = await hass.config_entries.flow.async_init( + config_flow.DOMAIN, context={"source": "user"} + ) with patch( "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", autospec=True, side_effect=raise_error, ) as mock_OpendataTransport: - result = await hass.config_entries.flow.async_init( - config_flow.DOMAIN, context={"source": "user"} - ) result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input=user_input_error, @@ -121,13 +179,75 @@ async def test_flow_user_init_data_error_and_recover( mock_OpendataTransport.return_value = True result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=MOCK_DATA_STEP, + user_input=MOCK_USER_DATA_STEP, ) assert result["type"] == FlowResultType.CREATE_ENTRY assert result["result"].title == "test_start test_destination" - assert result["data"] == MOCK_DATA_STEP + assert result["data"] == MOCK_USER_DATA_STEP + + +@pytest.mark.parametrize( + ("raise_error", "text_error", "user_input"), + [ + ( + OpendataTransportConnectionError(), + "cannot_connect", + MOCK_ADVANCED_DATA_STEP_TIME, + ), + (OpendataTransportError(), "bad_config", MOCK_ADVANCED_DATA_STEP_TIME), + (IndexError(), "unknown", MOCK_ADVANCED_DATA_STEP_TIME), + ], +) +async def test_flow_user_init_data_error_and_recover_on_step_2( + hass: HomeAssistant, raise_error, text_error, user_input +) -> None: + """Test errors in time mode step.""" + result = await hass.config_entries.flow.async_init( + config_flow.DOMAIN, context={"source": "user"} + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + assert result["handler"] == "swiss_public_transport" + assert result["data_schema"] == config_flow.USER_DATA_SCHEMA + + with patch( + "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", + autospec=True, + return_value=True, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=MOCK_USER_DATA_STEP_TIME_FIXED, + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "time_fixed" + + with patch( + "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", + autospec=True, + side_effect=raise_error, + ) as mock_OpendataTransport: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"]["base"] == text_error + + # Recover + mock_OpendataTransport.side_effect = None + mock_OpendataTransport.return_value = True + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["result"].title == "test_start test_destination at 18:03:00" async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> None: @@ -135,8 +255,8 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No entry = MockConfigEntry( domain=config_flow.DOMAIN, - data=MOCK_DATA_STEP, - unique_id=unique_id_from_config(MOCK_DATA_STEP), + data=MOCK_USER_DATA_STEP, + unique_id=unique_id_from_config(MOCK_USER_DATA_STEP), ) entry.add_to_hass(hass) @@ -151,7 +271,7 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=MOCK_DATA_STEP, + user_input=MOCK_USER_DATA_STEP, ) assert result["type"] is FlowResultType.ABORT diff --git a/tests/components/swiss_public_transport/test_init.py b/tests/components/swiss_public_transport/test_init.py index 9ad4a8d50b0833..963f5e6fa40a8f 100644 --- a/tests/components/swiss_public_transport/test_init.py +++ b/tests/components/swiss_public_transport/test_init.py @@ -7,6 +7,9 @@ from homeassistant.components.swiss_public_transport.const import ( CONF_DESTINATION, CONF_START, + CONF_TIME_FIXED, + CONF_TIME_OFFSET, + CONF_TIME_STATION, CONF_VIA, DOMAIN, ) @@ -28,6 +31,17 @@ CONF_VIA: ["via_station"], } +MOCK_DATA_STEP_TIME_FIXED = { + **MOCK_DATA_STEP_VIA, + CONF_TIME_FIXED: "18:03:00", +} + +MOCK_DATA_STEP_TIME_OFFSET = { + **MOCK_DATA_STEP_VIA, + CONF_TIME_OFFSET: {"hours": 0, "minutes": 10, "seconds": 0}, + CONF_TIME_STATION: "arrival", +} + CONNECTIONS = [ { "departure": "2024-01-06T18:03:00+0100", @@ -70,6 +84,8 @@ (1, 1, MOCK_DATA_STEP_BASE, "None_departure"), (1, 2, MOCK_DATA_STEP_BASE, None), (2, 1, MOCK_DATA_STEP_VIA, None), + (3, 1, MOCK_DATA_STEP_TIME_FIXED, None), + (3, 1, MOCK_DATA_STEP_TIME_OFFSET, None), ], ) async def test_migration_from( @@ -113,7 +129,7 @@ async def test_migration_from( ) # Check change in config entry and verify most recent version - assert config_entry.version == 2 + assert config_entry.version == 3 assert config_entry.minor_version == 1 assert config_entry.unique_id == unique_id @@ -130,7 +146,7 @@ async def test_migrate_error_from_future(hass: HomeAssistant) -> None: mock_entry = MockConfigEntry( domain=DOMAIN, - version=3, + version=4, minor_version=1, unique_id="some_crazy_future_unique_id", data=MOCK_DATA_STEP_BASE,