Skip to content

Commit

Permalink
Allow reloading automation
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Aug 26, 2016
1 parent 3fa1963 commit d231f50
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 41 deletions.
110 changes: 85 additions & 25 deletions homeassistant/components/automation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import voluptuous as vol

from homeassistant.bootstrap import prepare_setup_platform
from homeassistant import config as conf_util
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
SERVICE_TOGGLE)
Expand Down Expand Up @@ -46,6 +47,7 @@
ATTR_LAST_TRIGGERED = 'last_triggered'
ATTR_VARIABLES = 'variables'
SERVICE_TRIGGER = 'trigger'
SERVICE_RELOAD = 'reload'

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,6 +114,8 @@ def validator(config):
vol.Optional(ATTR_VARIABLES, default={}): dict,
})

RELOAD_SERVICE_SCHEMA = vol.Schema({})


def is_on(hass, entity_id=None):
"""
Expand Down Expand Up @@ -148,35 +152,16 @@ def trigger(hass, entity_id=None):
hass.services.call(DOMAIN, SERVICE_TRIGGER, data)


def reload(hass):
"""Reload the automation from config."""
hass.services.call(DOMAIN, SERVICE_RELOAD)


def setup(hass, config):
"""Setup the automation."""
component = EntityComponent(_LOGGER, DOMAIN, hass)

success = False
for config_key in extract_domain_configs(config, DOMAIN):
conf = config[config_key]

for list_no, config_block in enumerate(conf):
name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key,
list_no)

action = _get_action(hass, config_block.get(CONF_ACTION, {}), name)

if CONF_CONDITION in config_block:
cond_func = _process_if(hass, config, config_block)

if cond_func is None:
continue
else:
def cond_func(variables):
"""Condition will always pass."""
return True

attach_triggers = partial(_process_trigger, hass, config,
config_block.get(CONF_TRIGGER, []), name)
entity = AutomationEntity(name, attach_triggers, cond_func, action)
component.add_entities((entity,))
success = True
success = _process_config(hass, config, component)

if not success:
return False
Expand All @@ -191,9 +176,47 @@ def service_handler(service_call):
for entity in component.extract_from_service(service_call):
getattr(entity, service_call.service)()

def reload_service_handler(service_call):
"""Remove all automations and load new ones from config."""
try:
path = conf_util.find_config_file(hass.config.config_dir)
conf = conf_util.load_yaml_config_file(path)
except HomeAssistantError as err:
_LOGGER.error(err)
return

# For now copied from bootstrap.py
# Depends on work by @Kellerza to split this out
from homeassistant.bootstrap import config_per_platform, log_exception

platforms = []
for _, p_config in config_per_platform(conf, DOMAIN):
# Validate component specific platform schema
try:
p_validated = PLATFORM_SCHEMA(p_config)
except vol.MultipleInvalid as ex:
log_exception(ex, DOMAIN, p_config)
return

platforms.append(p_validated)

# Create a copy of the configuration with all config for current
# component removed and add validated config back in.
filter_keys = extract_domain_configs(conf, DOMAIN)
conf = {key: value for key, value in conf.items()
if key not in filter_keys}
conf[DOMAIN] = platforms
# End copied from bootstrap

component.reset()
_process_config(hass, conf, component)

hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
schema=TRIGGER_SERVICE_SCHEMA)

hass.services.register(DOMAIN, SERVICE_RELOAD, reload_service_handler,
schema=RELOAD_SERVICE_SCHEMA)

for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE):
hass.services.register(DOMAIN, service, service_handler,
schema=SERVICE_SCHEMA)
Expand Down Expand Up @@ -262,6 +285,43 @@ def trigger(self, variables):
self._last_triggered = utcnow()
self.update_ha_state()

def remove(self):
"""Remove automation from HASS."""
self.turn_off()
super().remove()


def _process_config(hass, config, component):
"""Process config and add automations."""
success = False

for config_key in extract_domain_configs(config, DOMAIN):
conf = config[config_key]

for list_no, config_block in enumerate(conf):
name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key,
list_no)

action = _get_action(hass, config_block.get(CONF_ACTION, {}), name)

if CONF_CONDITION in config_block:
cond_func = _process_if(hass, config, config_block)

if cond_func is None:
continue
else:
def cond_func(variables):
"""Condition will always pass."""
return True

attach_triggers = partial(_process_trigger, hass, config,
config_block.get(CONF_TRIGGER, []), name)
entity = AutomationEntity(name, attach_triggers, cond_func, action)
component.add_entities((entity,))
success = True

return success


def _get_action(hass, config, name):
"""Return an action based on a configuration."""
Expand Down
4 changes: 4 additions & 0 deletions homeassistant/helpers/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def update_ha_state(self, force_refresh=False):
return self.hass.states.set(
self.entity_id, state, attr, self.force_update)

def remove(self) -> None:
"""Remove entitiy from HASS."""
self.hass.states.remove(self.entity_id)

def _attr_setter(self, name, typ, attr, attrs):
"""Helper method to populate attributes based on properties."""
if attr in attrs:
Expand Down
61 changes: 45 additions & 16 deletions homeassistant/helpers/entity_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def __init__(self, logger, domain, hass,

self.entities = {}
self.group = None
self.is_polling = False

self.config = None
self.lock = Lock()

self.add_entities = EntityPlatform(self, self.scan_interval,
None).add_entities
self._platforms = {
'core': EntityPlatform(self, self.scan_interval, None),
}
self.add_entities = self._platforms['core'].add_entities

def setup(self, config):
"""Set up a full entity component.
Expand Down Expand Up @@ -85,17 +86,23 @@ def _setup_platform(self, platform_type, platform_config,
return

# Config > Platform > Component
scan_interval = platform_config.get(
CONF_SCAN_INTERVAL,
getattr(platform, 'SCAN_INTERVAL', self.scan_interval))
scan_interval = (platform_config.get(CONF_SCAN_INTERVAL) or
getattr(platform, 'SCAN_INTERVAL', None) or
self.scan_interval)
entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE)

key = (platform_type, scan_interval, entity_namespace)

if key in self._platforms:
entity_platform = self._platforms[key]
else:
self._platforms[key] = EntityPlatform(self, scan_interval,
entity_namespace)

try:
platform.setup_platform(
self.hass, platform_config,
EntityPlatform(self, scan_interval,
entity_namespace).add_entities,
discovery_info)
platform.setup_platform(self.hass, platform_config,
entity_platform.add_entities,
discovery_info)

self.hass.config.components.append(
'{}.{}'.format(self.domain, platform_type))
Expand Down Expand Up @@ -135,6 +142,22 @@ def update_group(self):
if self.group is not None:
self.group.update_tracked_entity_ids(self.entities.keys())

def reset(self):
"""Remove entities and reset the entity component to initial values."""
with self.lock:
for platform in self._platforms.values():
platform.reset()

self._platforms = {
'core': self._platforms['core']
}
self.entities = {}
self.config = None

if self.group is not None:
self.group.stop()
self.group = None


class EntityPlatform(object):
"""Keep track of entities for a single platform."""
Expand All @@ -146,7 +169,7 @@ def __init__(self, component, scan_interval, entity_namespace):
self.scan_interval = scan_interval
self.entity_namespace = entity_namespace
self.platform_entities = []
self.is_polling = False
self._unsub_polling = None

def add_entities(self, new_entities):
"""Add entities for a single platform."""
Expand All @@ -157,17 +180,23 @@ def add_entities(self, new_entities):

self.component.update_group()

if self.is_polling or \
if self._unsub_polling is not None or \
not any(entity.should_poll for entity
in self.platform_entities):
return

self.is_polling = True

track_utc_time_change(
self._unsub_polling = track_utc_time_change(
self.component.hass, self._update_entity_states,
second=range(0, 60, self.scan_interval))

def reset(self):
"""Remove all entities and reset data."""
for entity in self.platform_entities:
entity.remove()
if self._unsub_polling is not None:
self._unsub_polling()
self._unsub_polling = None

def _update_entity_states(self, now):
"""Update the states of all the polling entities."""
with self.component.lock:
Expand Down
Loading

0 comments on commit d231f50

Please sign in to comment.