Skip to content

Commit

Permalink
Add template support to Bayesian sensor (home-assistant#20757)
Browse files Browse the repository at this point in the history
* Add template support to Bayesian sensor

* Removed unused import
  • Loading branch information
arsaboo authored and bachya committed Feb 14, 2019
1 parent c20e0b9 commit bf0a50c
Showing 1 changed file with 36 additions and 14 deletions.
50 changes: 36 additions & 14 deletions homeassistant/components/binary_sensor/bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/binary_sensor.bayesian/
"""
import logging
from collections import OrderedDict

import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.binary_sensor import (
BinarySensorDevice, PLATFORM_SCHEMA)
PLATFORM_SCHEMA, BinarySensorDevice)
from homeassistant.const import (
CONF_ABOVE, CONF_BELOW, CONF_DEVICE_CLASS, CONF_ENTITY_ID, CONF_NAME,
CONF_PLATFORM, CONF_STATE, STATE_UNKNOWN)
CONF_PLATFORM, CONF_STATE, CONF_VALUE_TEMPLATE, STATE_UNKNOWN)
from homeassistant.core import callback
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change

_LOGGER = logging.getLogger(__name__)

ATTR_OBSERVATIONS = 'observations'
ATTR_PROBABILITY = 'probability'
ATTR_PROBABILITY_THRESHOLD = 'probability_threshold'

CONF_OBSERVATIONS = 'observations'
CONF_PRIOR = 'prior'
CONF_TEMPLATE = "template"
CONF_PROBABILITY_THRESHOLD = 'probability_threshold'
CONF_P_GIVEN_F = 'prob_given_false'
CONF_P_GIVEN_T = 'prob_given_true'
Expand All @@ -52,12 +50,20 @@
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float)
}, required=True)

TEMPLATE_SCHEMA = vol.Schema({
CONF_PLATFORM: CONF_TEMPLATE,
vol.Required(CONF_VALUE_TEMPLATE): cv.template,
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float)
}, required=True)

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_DEVICE_CLASS): cv.string,
vol.Required(CONF_OBSERVATIONS):
vol.Schema(vol.All(cv.ensure_list,
[vol.Any(NUMERIC_STATE_SCHEMA, STATE_SCHEMA)])),
[vol.Any(NUMERIC_STATE_SCHEMA, STATE_SCHEMA,
TEMPLATE_SCHEMA)])),
vol.Required(CONF_PRIOR): vol.Coerce(float),
vol.Optional(CONF_PROBABILITY_THRESHOLD,
default=DEFAULT_PROBABILITY_THRESHOLD): vol.Coerce(float),
Expand All @@ -68,7 +74,6 @@ def update_probability(prior, prob_true, prob_false):
"""Update probability using Bayes' rule."""
numerator = prob_true * prior
denominator = numerator + prob_false * (1 - prior)

probability = numerator / denominator
return probability

Expand Down Expand Up @@ -104,17 +109,27 @@ def __init__(self, name, prior, observations, probability_threshold,

self.current_obs = OrderedDict({})

to_observe = set(obs['entity_id'] for obs in self._observations)

to_observe = set()
for obs in self._observations:
if 'entity_id' in obs:
to_observe.update(set([obs.get('entity_id')]))
if 'value_template' in obs:
to_observe.update(
set(obs.get(CONF_VALUE_TEMPLATE).extract_entities()))
self.entity_obs = dict.fromkeys(to_observe, [])

for ind, obs in enumerate(self._observations):
obs['id'] = ind
self.entity_obs[obs['entity_id']].append(obs)
if 'entity_id' in obs:
self.entity_obs[obs['entity_id']].append(obs)
if 'value_template' in obs:
for ent in obs.get(CONF_VALUE_TEMPLATE).extract_entities():
self.entity_obs[ent].append(obs)

self.watchers = {
'numeric_state': self._process_numeric_state,
'state': self._process_state
'state': self._process_state,
'template': self._process_template
}

async def async_added_to_hass(self):
Expand All @@ -141,9 +156,8 @@ def async_threshold_sensor_state_listener(entity, old_state,

self.hass.async_add_job(self.async_update_ha_state, True)

entities = [obs['entity_id'] for obs in self._observations]
async_track_state_change(
self.hass, entities, async_threshold_sensor_state_listener)
self.hass, self.entity_obs, async_threshold_sensor_state_listener)

def _update_current_obs(self, entity_observation, should_trigger):
"""Update current observation."""
Expand Down Expand Up @@ -182,6 +196,14 @@ def _process_state(self, entity_observation):

self._update_current_obs(entity_observation, should_trigger)

def _process_template(self, entity_observation):
"""Add entity to current_obs if template is true."""
template = entity_observation.get(CONF_VALUE_TEMPLATE)
template.hass = self.hass
should_trigger = condition.async_template(
self.hass, template, entity_observation)
self._update_current_obs(entity_observation, should_trigger)

@property
def name(self):
"""Return the name of the sensor."""
Expand Down

0 comments on commit bf0a50c

Please sign in to comment.