diff --git a/homeassistant/components/axis/camera.py b/homeassistant/components/axis/camera.py index 649e512718c2c6..c4cc5df68a0216 100644 --- a/homeassistant/components/axis/camera.py +++ b/homeassistant/components/axis/camera.py @@ -32,6 +32,9 @@ async def async_setup_entry(hass, config_entry, async_add_entities): device = hass.data[AXIS_DOMAIN][config_entry.unique_id] + if not device.option_camera: + return + config = { CONF_NAME: config_entry.data[CONF_NAME], CONF_USERNAME: config_entry.data[CONF_USERNAME], diff --git a/homeassistant/components/axis/const.py b/homeassistant/components/axis/const.py index 1d52677b30cdd9..05a1211f89d0c6 100644 --- a/homeassistant/components/axis/const.py +++ b/homeassistant/components/axis/const.py @@ -1,6 +1,10 @@ """Constants for the Axis component.""" import logging +from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN +from homeassistant.components.camera import DOMAIN as CAMERA_DOMAIN +from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN + LOGGER = logging.getLogger(__package__) DOMAIN = "axis" @@ -13,3 +17,5 @@ DEFAULT_EVENTS = True DEFAULT_TRIGGER_TIME = 0 + +PLATFORMS = [BINARY_SENSOR_DOMAIN, CAMERA_DOMAIN, SWITCH_DOMAIN] diff --git a/homeassistant/components/axis/device.py b/homeassistant/components/axis/device.py index 57d2d1be5d71b9..3483bfbea2eb21 100644 --- a/homeassistant/components/axis/device.py +++ b/homeassistant/components/axis/device.py @@ -7,9 +7,6 @@ from axis.event_stream import OPERATION_INITIALIZED from axis.streammanager import SIGNAL_PLAYING -from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN -from homeassistant.components.camera import DOMAIN as CAMERA_DOMAIN -from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.const import ( CONF_HOST, CONF_NAME, @@ -32,6 +29,7 @@ DEFAULT_TRIGGER_TIME, DOMAIN as AXIS_DOMAIN, LOGGER, + PLATFORMS, ) from .errors import AuthenticationRequired, CannotConnect @@ -165,38 +163,28 @@ async def async_setup(self): self.fw_version = self.api.vapix.params.firmware_version self.product_type = self.api.vapix.params.prodtype - if self.option_camera: - - self.hass.async_create_task( - self.hass.config_entries.async_forward_entry_setup( - self.config_entry, CAMERA_DOMAIN - ) - ) - - if self.option_events: - - self.api.stream.connection_status_callback = ( - self.async_connection_status_callback + async def start_platforms(): + await asyncio.gather( + *[ + self.hass.config_entries.async_forward_entry_setup( + self.config_entry, platform + ) + for platform in PLATFORMS + ] ) - self.api.enable_events(event_callback=self.async_event_callback) - - platform_tasks = [ - self.hass.config_entries.async_forward_entry_setup( - self.config_entry, platform + if self.option_events: + self.api.stream.connection_status_callback = ( + self.async_connection_status_callback ) - for platform in [BINARY_SENSOR_DOMAIN, SWITCH_DOMAIN] - ] - self.hass.async_create_task(self.start(platform_tasks)) + self.api.enable_events(event_callback=self.async_event_callback) + self.api.start() + + self.hass.async_create_task(start_platforms()) self.config_entry.add_update_listener(self.async_new_address_callback) return True - async def start(self, platform_tasks): - """Start the event stream when all platforms are loaded.""" - await asyncio.gather(*platform_tasks) - self.api.start() - @callback def shutdown(self, event): """Stop the event stream.""" @@ -204,29 +192,23 @@ def shutdown(self, event): async def async_reset(self): """Reset this device to default state.""" - platform_tasks = [] + self.api.stop() - if self.config_entry.options[CONF_CAMERA]: - platform_tasks.append( - self.hass.config_entries.async_forward_entry_unload( - self.config_entry, CAMERA_DOMAIN - ) + unload_ok = all( + await asyncio.gather( + *[ + self.hass.config_entries.async_forward_entry_unload( + self.config_entry, platform + ) + for platform in PLATFORMS + ] ) + ) + if not unload_ok: + return False - if self.config_entry.options[CONF_EVENTS]: - self.api.stop() - platform_tasks += [ - self.hass.config_entries.async_forward_entry_unload( - self.config_entry, platform - ) - for platform in [BINARY_SENSOR_DOMAIN, SWITCH_DOMAIN] - ] - - await asyncio.gather(*platform_tasks) - - for unsub_dispatcher in self.listeners: - unsub_dispatcher() - self.listeners = [] + for unsubscribe_listener in self.listeners: + unsubscribe_listener() return True diff --git a/tests/components/axis/test_device.py b/tests/components/axis/test_device.py index ec350695e6b39b..facb6f7de42cb7 100644 --- a/tests/components/axis/test_device.py +++ b/tests/components/axis/test_device.py @@ -110,7 +110,7 @@ def mock_update_properties(self): await hass.config_entries.async_setup(config_entry.entry_id) await hass.async_block_till_done() - return hass.data[AXIS_DOMAIN].get(config[CONF_MAC]) + return hass.data[AXIS_DOMAIN].get(config_entry.unique_id) async def test_device_setup(hass): @@ -124,8 +124,8 @@ async def test_device_setup(hass): entry = device.config_entry assert len(forward_entry_setup.mock_calls) == 3 - assert forward_entry_setup.mock_calls[0][1] == (entry, "camera") - assert forward_entry_setup.mock_calls[1][1] == (entry, "binary_sensor") + assert forward_entry_setup.mock_calls[0][1] == (entry, "binary_sensor") + assert forward_entry_setup.mock_calls[1][1] == (entry, "camera") assert forward_entry_setup.mock_calls[2][1] == (entry, "switch") assert device.host == ENTRY_CONFIG[CONF_HOST]