From d231f5097de28f0bf1b50a24a7843a3c6e2342b9 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 26 Aug 2016 00:30:56 -0700 Subject: [PATCH] Allow reloading automation --- .../components/automation/__init__.py | 110 +++++++++++---- homeassistant/helpers/entity.py | 4 + homeassistant/helpers/entity_component.py | 61 +++++--- tests/components/automation/test_init.py | 130 ++++++++++++++++++ 4 files changed, 264 insertions(+), 41 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index fe443515e8aee..f42de618c38f5 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -10,6 +10,7 @@ import voluptuous as vol from homeassistant.bootstrap import prepare_setup_platform +from homeassistant import config as conf_util from homeassistant.const import ( ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE) @@ -46,6 +47,7 @@ ATTR_LAST_TRIGGERED = 'last_triggered' ATTR_VARIABLES = 'variables' SERVICE_TRIGGER = 'trigger' +SERVICE_RELOAD = 'reload' _LOGGER = logging.getLogger(__name__) @@ -112,6 +114,8 @@ def validator(config): vol.Optional(ATTR_VARIABLES, default={}): dict, }) +RELOAD_SERVICE_SCHEMA = vol.Schema({}) + def is_on(hass, entity_id=None): """ @@ -148,35 +152,16 @@ def trigger(hass, entity_id=None): hass.services.call(DOMAIN, SERVICE_TRIGGER, data) +def reload(hass): + """Reload the automation from config.""" + hass.services.call(DOMAIN, SERVICE_RELOAD) + + def setup(hass, config): """Setup the automation.""" component = EntityComponent(_LOGGER, DOMAIN, hass) - success = False - for config_key in extract_domain_configs(config, DOMAIN): - conf = config[config_key] - - for list_no, config_block in enumerate(conf): - name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key, - list_no) - - action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) - - if CONF_CONDITION in config_block: - cond_func = _process_if(hass, config, config_block) - - if cond_func is None: - continue - else: - def cond_func(variables): - """Condition will always pass.""" - return True - - attach_triggers = partial(_process_trigger, hass, config, - config_block.get(CONF_TRIGGER, []), name) - entity = AutomationEntity(name, attach_triggers, cond_func, action) - component.add_entities((entity,)) - success = True + success = _process_config(hass, config, component) if not success: return False @@ -191,9 +176,47 @@ def service_handler(service_call): for entity in component.extract_from_service(service_call): getattr(entity, service_call.service)() + def reload_service_handler(service_call): + """Remove all automations and load new ones from config.""" + try: + path = conf_util.find_config_file(hass.config.config_dir) + conf = conf_util.load_yaml_config_file(path) + except HomeAssistantError as err: + _LOGGER.error(err) + return + + # For now copied from bootstrap.py + # Depends on work by @Kellerza to split this out + from homeassistant.bootstrap import config_per_platform, log_exception + + platforms = [] + for _, p_config in config_per_platform(conf, DOMAIN): + # Validate component specific platform schema + try: + p_validated = PLATFORM_SCHEMA(p_config) + except vol.MultipleInvalid as ex: + log_exception(ex, DOMAIN, p_config) + return + + platforms.append(p_validated) + + # Create a copy of the configuration with all config for current + # component removed and add validated config back in. + filter_keys = extract_domain_configs(conf, DOMAIN) + conf = {key: value for key, value in conf.items() + if key not in filter_keys} + conf[DOMAIN] = platforms + # End copied from bootstrap + + component.reset() + _process_config(hass, conf, component) + hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler, schema=TRIGGER_SERVICE_SCHEMA) + hass.services.register(DOMAIN, SERVICE_RELOAD, reload_service_handler, + schema=RELOAD_SERVICE_SCHEMA) + for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE): hass.services.register(DOMAIN, service, service_handler, schema=SERVICE_SCHEMA) @@ -262,6 +285,43 @@ def trigger(self, variables): self._last_triggered = utcnow() self.update_ha_state() + def remove(self): + """Remove automation from HASS.""" + self.turn_off() + super().remove() + + +def _process_config(hass, config, component): + """Process config and add automations.""" + success = False + + for config_key in extract_domain_configs(config, DOMAIN): + conf = config[config_key] + + for list_no, config_block in enumerate(conf): + name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key, + list_no) + + action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) + + if CONF_CONDITION in config_block: + cond_func = _process_if(hass, config, config_block) + + if cond_func is None: + continue + else: + def cond_func(variables): + """Condition will always pass.""" + return True + + attach_triggers = partial(_process_trigger, hass, config, + config_block.get(CONF_TRIGGER, []), name) + entity = AutomationEntity(name, attach_triggers, cond_func, action) + component.add_entities((entity,)) + success = True + + return success + def _get_action(hass, config, name): """Return an action based on a configuration.""" diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 61cda43d431a0..0b4768b809d55 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -195,6 +195,10 @@ def update_ha_state(self, force_refresh=False): return self.hass.states.set( self.entity_id, state, attr, self.force_update) + def remove(self) -> None: + """Remove entitiy from HASS.""" + self.hass.states.remove(self.entity_id) + def _attr_setter(self, name, typ, attr, attrs): """Helper method to populate attributes based on properties.""" if attr in attrs: diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 898a445c78807..e4d5159682208 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -32,13 +32,14 @@ def __init__(self, logger, domain, hass, self.entities = {} self.group = None - self.is_polling = False self.config = None self.lock = Lock() - self.add_entities = EntityPlatform(self, self.scan_interval, - None).add_entities + self._platforms = { + 'core': EntityPlatform(self, self.scan_interval, None), + } + self.add_entities = self._platforms['core'].add_entities def setup(self, config): """Set up a full entity component. @@ -85,17 +86,23 @@ def _setup_platform(self, platform_type, platform_config, return # Config > Platform > Component - scan_interval = platform_config.get( - CONF_SCAN_INTERVAL, - getattr(platform, 'SCAN_INTERVAL', self.scan_interval)) + scan_interval = (platform_config.get(CONF_SCAN_INTERVAL) or + getattr(platform, 'SCAN_INTERVAL', None) or + self.scan_interval) entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE) + key = (platform_type, scan_interval, entity_namespace) + + if key in self._platforms: + entity_platform = self._platforms[key] + else: + self._platforms[key] = EntityPlatform(self, scan_interval, + entity_namespace) + try: - platform.setup_platform( - self.hass, platform_config, - EntityPlatform(self, scan_interval, - entity_namespace).add_entities, - discovery_info) + platform.setup_platform(self.hass, platform_config, + entity_platform.add_entities, + discovery_info) self.hass.config.components.append( '{}.{}'.format(self.domain, platform_type)) @@ -135,6 +142,22 @@ def update_group(self): if self.group is not None: self.group.update_tracked_entity_ids(self.entities.keys()) + def reset(self): + """Remove entities and reset the entity component to initial values.""" + with self.lock: + for platform in self._platforms.values(): + platform.reset() + + self._platforms = { + 'core': self._platforms['core'] + } + self.entities = {} + self.config = None + + if self.group is not None: + self.group.stop() + self.group = None + class EntityPlatform(object): """Keep track of entities for a single platform.""" @@ -146,7 +169,7 @@ def __init__(self, component, scan_interval, entity_namespace): self.scan_interval = scan_interval self.entity_namespace = entity_namespace self.platform_entities = [] - self.is_polling = False + self._unsub_polling = None def add_entities(self, new_entities): """Add entities for a single platform.""" @@ -157,17 +180,23 @@ def add_entities(self, new_entities): self.component.update_group() - if self.is_polling or \ + if self._unsub_polling is not None or \ not any(entity.should_poll for entity in self.platform_entities): return - self.is_polling = True - - track_utc_time_change( + self._unsub_polling = track_utc_time_change( self.component.hass, self._update_entity_states, second=range(0, 60, self.scan_interval)) + def reset(self): + """Remove all entities and reset data.""" + for entity in self.platform_entities: + entity.remove() + if self._unsub_polling is not None: + self._unsub_polling() + self._unsub_polling = None + def _update_entity_states(self, now): """Update the states of all the polling entities.""" with self.component.lock: diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 744bd0becfb36..89175f32f491e 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -5,6 +5,7 @@ from homeassistant.bootstrap import _setup_component import homeassistant.components.automation as automation from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.exceptions import HomeAssistantError import homeassistant.util.dt as dt_util from tests.common import get_test_home_assistant @@ -414,3 +415,132 @@ def test_services(self): automation.turn_on(self.hass, entity_id) self.hass.pool.block_till_done() assert automation.is_on(self.hass, entity_id) + + @patch('homeassistant.config.load_yaml_config_file', return_value={ + automation.DOMAIN: { + 'alias': 'bye', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event2', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + def test_reload_config_service(self, mock_load_yaml): + """Test the reload config service.""" + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + assert self.hass.states.get('automation.hello') is not None + assert self.hass.states.get('automation.bye') is None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 1 + assert self.calls[0].data.get('event') == 'test_event' + + automation.reload(self.hass) + self.hass.pool.block_till_done() + + assert self.hass.states.get('automation.hello') is None + assert self.hass.states.get('automation.bye') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 1 + + self.hass.bus.fire('test_event2') + self.hass.pool.block_till_done() + assert len(self.calls) == 2 + assert self.calls[1].data.get('event') == 'test_event2' + + @patch('homeassistant.config.load_yaml_config_file', return_value={ + automation.DOMAIN: 'not valid', + }) + def test_reload_config_when_invalid_config(self, mock_load_yaml): + """Test the reload config service handling invalid config.""" + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 1 + assert self.calls[0].data.get('event') == 'test_event' + + automation.reload(self.hass) + self.hass.pool.block_till_done() + + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 2 + + @patch('homeassistant.config.load_yaml_config_file', + side_effect=HomeAssistantError('bla')) + def test_reload_config_handles_load_fails(self, mock_load_yaml): + """Test the reload config service.""" + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + 'data_template': { + 'event': '{{ trigger.event.event_type }}' + } + } + } + }) + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + + assert len(self.calls) == 1 + assert self.calls[0].data.get('event') == 'test_event' + + automation.reload(self.hass) + self.hass.pool.block_till_done() + + assert self.hass.states.get('automation.hello') is not None + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 2