Skip to content

Commit

Permalink
Update device_traker for async platforms (home-assistant#5102)
Browse files Browse the repository at this point in the history
Async DeviceScanner object, migrate to async, cleanups
  • Loading branch information
pvizeli authored and kellerza committed Jan 2, 2017
1 parent 9c6a985 commit b2371c6
Show file tree
Hide file tree
Showing 26 changed files with 124 additions and 72 deletions.
71 changes: 56 additions & 15 deletions homeassistant/components/device_tracker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import timedelta
import logging
import os
from typing import Any, Sequence, Callable
from typing import Any, List, Sequence, Callable

import aiohttp
import async_timeout
Expand Down Expand Up @@ -142,23 +142,34 @@ def async_setup_platform(p_type, p_config, disc_info=None):
if platform is None:
return

_LOGGER.info("Setting up %s.%s", DOMAIN, p_type)
try:
if hasattr(platform, 'get_scanner'):
scanner = None
setup = None
if hasattr(platform, 'async_get_scanner'):
scanner = yield from platform.async_get_scanner(
hass, {DOMAIN: p_config})
elif hasattr(platform, 'get_scanner'):
scanner = yield from hass.loop.run_in_executor(
None, platform.get_scanner, hass, {DOMAIN: p_config})
elif hasattr(platform, 'async_setup_scanner'):
setup = yield from platform.setup_scanner(
hass, p_config, tracker.see)
elif hasattr(platform, 'setup_scanner'):
setup = yield from hass.loop.run_in_executor(
None, platform.setup_scanner, hass, p_config, tracker.see)
else:
raise HomeAssistantError("Invalid device_tracker platform.")

if scanner is None:
_LOGGER.error('Error setting up platform %s', p_type)
return

if scanner:
yield from async_setup_scanner_platform(
hass, p_config, scanner, tracker.async_see)
return

ret = yield from hass.loop.run_in_executor(
None, platform.setup_scanner, hass, p_config, tracker.see)
if not ret:
if not setup:
_LOGGER.error('Error setting up platform %s', p_type)
return

except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error setting up platform %s', p_type)

Expand Down Expand Up @@ -526,6 +537,34 @@ def get_vendor_for_mac(self):
yield from resp.release()


class DeviceScanner(object):
"""Device scanner object."""

hass = None # type: HomeAssistantType

def scan_devices(self) -> List[str]:
"""Scan for devices."""
raise NotImplementedError()

def async_scan_devices(self) -> Any:
"""Scan for devices.
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(None, self.scan_devices)

def get_device_name(self, mac: str) -> str:
"""Get device name from mac."""
raise NotImplementedError()

def async_get_device_name(self, mac: str) -> Any:
"""Get device name from mac.
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(None, self.get_device_name, mac)


def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
"""Load devices from YAML configuration file."""
return run_coroutine_threadsafe(
Expand Down Expand Up @@ -582,26 +621,28 @@ def async_setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
This method is a coroutine.
"""
interval = config.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
scanner.hass = hass

# Initial scan of each mac we also tell about host name for config
seen = set() # type: Any

def device_tracker_scan(now: dt_util.dt.datetime):
@asyncio.coroutine
def async_device_tracker_scan(now: dt_util.dt.datetime):
"""Called when interval matches."""
found_devices = scanner.scan_devices()
found_devices = yield from scanner.async_scan_devices()

for mac in found_devices:
if mac in seen:
host_name = None
else:
host_name = scanner.get_device_name(mac)
host_name = yield from scanner.async_get_device_name(mac)
seen.add(mac)
hass.add_job(async_see_device(mac=mac, host_name=host_name))
hass.async_add_job(async_see_device(mac=mac, host_name=host_name))

async_track_utc_time_change(
hass, device_tracker_scan, second=range(0, 60, interval))
hass, async_device_tracker_scan, second=range(0, 60, interval))

hass.async_add_job(device_tracker_scan, None)
hass.async_add_job(async_device_tracker_scan, None)


def update_config(path: str, dev_id: str, device: Device):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/actiontec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import (DOMAIN, PLATFORM_SCHEMA)
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle

Expand Down Expand Up @@ -46,7 +47,7 @@ def get_scanner(hass, config):
Device = namedtuple("Device", ["mac", "ip", "last_update"])


class ActiontecDeviceScanner(object):
class ActiontecDeviceScanner(DeviceScanner):
"""This class queries a an actiontec router for connected devices."""

def __init__(self, config):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/aruba.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle

Expand Down Expand Up @@ -42,7 +43,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None


class ArubaDeviceScanner(object):
class ArubaDeviceScanner(DeviceScanner):
"""This class queries a Aruba Access Point for connected devices."""

def __init__(self, config):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/asuswrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import voluptuous as vol

from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
import homeassistant.helpers.config_validation as cv
Expand Down Expand Up @@ -97,7 +98,7 @@ def get_scanner(hass, config):
AsusWrtResult = namedtuple('AsusWrtResult', 'neighbors leases arp nvram')


class AsusWrtDeviceScanner(object):
class AsusWrtDeviceScanner(DeviceScanner):
"""This class queries a router running ASUSWRT firmware."""

# Eighth attribute needed for mode (AP mode vs router mode)
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/device_tracker/automatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import voluptuous as vol

from homeassistant.components.device_tracker import (PLATFORM_SCHEMA,
ATTR_ATTRIBUTES)
from homeassistant.components.device_tracker import (
PLATFORM_SCHEMA, ATTR_ATTRIBUTES)
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import track_utc_time_change
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/device_tracker/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import timedelta

import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import DOMAIN
from homeassistant.components.device_tracker import DOMAIN, DeviceScanner
from homeassistant.util import Throttle

REQUIREMENTS = ['pybbox==0.0.5-alpha']
Expand All @@ -29,7 +29,7 @@ def get_scanner(hass, config):
Device = namedtuple('Device', ['mac', 'name', 'ip', 'last_update'])


class BboxDeviceScanner(object):
class BboxDeviceScanner(DeviceScanner):
"""This class scans for devices connected to the bbox."""

def __init__(self, config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@
import voluptuous as vol
from homeassistant.helpers.event import track_point_in_utc_time
from homeassistant.components.device_tracker import (
YAML_DEVICES,
CONF_TRACK_NEW,
CONF_SCAN_INTERVAL,
DEFAULT_SCAN_INTERVAL,
PLATFORM_SCHEMA,
load_config,
DEFAULT_TRACK_NEW
YAML_DEVICES, CONF_TRACK_NEW, CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL,
PLATFORM_SCHEMA, load_config, DEFAULT_TRACK_NEW
)
import homeassistant.util as util
import homeassistant.util.dt as dt_util
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/bt_home_hub_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST
from homeassistant.util import Throttle

Expand All @@ -40,7 +41,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None


class BTHomeHub5DeviceScanner(object):
class BTHomeHub5DeviceScanner(DeviceScanner):
"""This class queries a BT Home Hub 5."""

def __init__(self, config):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/cisco_ios.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, \
CONF_PORT
from homeassistant.util import Throttle
Expand Down Expand Up @@ -39,7 +40,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None


class CiscoDeviceScanner(object):
class CiscoDeviceScanner(DeviceScanner):
"""This class queries a wireless router running Cisco IOS firmware."""

def __init__(self, config):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/ddwrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle

Expand Down Expand Up @@ -41,7 +42,7 @@ def get_scanner(hass, config):
return None


class DdWrtDeviceScanner(object):
class DdWrtDeviceScanner(DeviceScanner):
"""This class queries a wireless router running DD-WRT firmware."""

def __init__(self, config):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/fritz.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle

Expand Down Expand Up @@ -38,7 +39,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None


class FritzBoxScanner(object):
class FritzBoxScanner(DeviceScanner):
"""This class queries a FRITZ!Box router."""

def __init__(self, config):
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/device_tracker/icloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from homeassistant.const import CONF_USERNAME, CONF_PASSWORD
from homeassistant.components.device_tracker import (
PLATFORM_SCHEMA, DOMAIN, ATTR_ATTRIBUTES, ENTITY_ID_FORMAT)
PLATFORM_SCHEMA, DOMAIN, ATTR_ATTRIBUTES, ENTITY_ID_FORMAT, DeviceScanner)
from homeassistant.components.zone import active_zone
from homeassistant.helpers.event import track_utc_time_change
import homeassistant.helpers.config_validation as cv
Expand Down Expand Up @@ -131,7 +131,7 @@ def setinterval(call):
return True


class Icloud(object):
class Icloud(DeviceScanner):
"""Represent an icloud account in Home Assistant."""

def __init__(self, hass, username, password, name, see):
Expand Down
5 changes: 2 additions & 3 deletions homeassistant/components/device_tracker/locative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from functools import partial
import logging

from homeassistant.const import (ATTR_LATITUDE, ATTR_LONGITUDE,
STATE_NOT_HOME,
HTTP_UNPROCESSABLE_ENTITY)
from homeassistant.const import (
ATTR_LATITUDE, ATTR_LONGITUDE, STATE_NOT_HOME, HTTP_UNPROCESSABLE_ENTITY)
from homeassistant.components.http import HomeAssistantView
# pylint: disable=unused-import
from homeassistant.components.device_tracker import ( # NOQA
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/luci.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle

Expand All @@ -37,7 +38,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None


class LuciDeviceScanner(object):
class LuciDeviceScanner(DeviceScanner):
"""This class queries a wireless router running OpenWrt firmware.
Adapted from Tomato scanner.
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/netgear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import voluptuous as vol

import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import (
CONF_HOST, CONF_PASSWORD, CONF_USERNAME, CONF_PORT)
from homeassistant.util import Throttle
Expand Down Expand Up @@ -47,7 +48,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None


class NetgearDeviceScanner(object):
class NetgearDeviceScanner(DeviceScanner):
"""Queries a Netgear wireless router using the SOAP-API."""

def __init__(self, host, username, password, port):
Expand Down
5 changes: 3 additions & 2 deletions homeassistant/components/device_tracker/nmap_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOSTS
from homeassistant.util import Throttle

Expand Down Expand Up @@ -63,7 +64,7 @@ def _arp(ip_address):
return None


class NmapDeviceScanner(object):
class NmapDeviceScanner(DeviceScanner):
"""This class scans for devices using nmap."""

exclude = []
Expand Down
Loading

0 comments on commit b2371c6

Please sign in to comment.