Skip to content

Commit

Permalink
Handle removal of accessories/services/chars in homekit_controller (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 18, 2023
1 parent e2e9c84 commit c3d1db5
Show file tree
Hide file tree
Showing 11 changed files with 1,907 additions and 65 deletions.
73 changes: 50 additions & 23 deletions homeassistant/components/homekit_controller/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.event import async_call_later, async_track_time_interval

from .config_flow import normalize_hkid
from .const import (
Expand All @@ -43,6 +43,7 @@
IDENTIFIER_LEGACY_SERIAL_NUMBER,
IDENTIFIER_SERIAL_NUMBER,
STARTUP_EXCEPTIONS,
SUBSCRIBE_COOLDOWN,
)
from .device_trigger import async_fire_triggers, async_setup_triggers_for_entry

Expand Down Expand Up @@ -116,15 +117,15 @@ def __init__(

# This just tracks aid/iid pairs so we know if a HK service has been
# mapped to a HA entity.
self.entities: list[tuple[int, int | None, int | None]] = []
self.entities: set[tuple[int, int | None, int | None]] = set()

# A map of aid -> device_id
# Useful when routing events to triggers
self.devices: dict[int, str] = {}

self.available = False

self.pollable_characteristics: list[tuple[int, int]] = []
self.pollable_characteristics: set[tuple[int, int]] = set()

# Never allow concurrent polling of the same accessory or bridge
self._polling_lock = asyncio.Lock()
Expand All @@ -134,7 +135,7 @@ def __init__(
# This is set to True if we can't rely on serial numbers to be unique
self.unreliable_serial_numbers = False

self.watchable_characteristics: list[tuple[int, int]] = []
self.watchable_characteristics: set[tuple[int, int]] = set()

self._debounced_update = Debouncer(
hass,
Expand All @@ -147,6 +148,8 @@ def __init__(
self._availability_callbacks: set[CALLBACK_TYPE] = set()
self._config_changed_callbacks: set[CALLBACK_TYPE] = set()
self._subscriptions: dict[tuple[int, int], set[CALLBACK_TYPE]] = {}
self._pending_subscribes: set[tuple[int, int]] = set()
self._subscribe_timer: CALLBACK_TYPE | None = None

@property
def entity_map(self) -> Accessories:
Expand All @@ -162,26 +165,51 @@ def add_pollable_characteristics(
self, characteristics: list[tuple[int, int]]
) -> None:
"""Add (aid, iid) pairs that we need to poll."""
self.pollable_characteristics.extend(characteristics)
self.pollable_characteristics.update(characteristics)

def remove_pollable_characteristics(self, accessory_id: int) -> None:
def remove_pollable_characteristics(
self, characteristics: list[tuple[int, int]]
) -> None:
"""Remove all pollable characteristics by accessory id."""
self.pollable_characteristics = [
char for char in self.pollable_characteristics if char[0] != accessory_id
]
for aid_iid in characteristics:
self.pollable_characteristics.discard(aid_iid)

async def add_watchable_characteristics(
def add_watchable_characteristics(
self, characteristics: list[tuple[int, int]]
) -> None:
"""Add (aid, iid) pairs that we need to poll."""
self.watchable_characteristics.extend(characteristics)
await self.pairing.subscribe(characteristics)
self.watchable_characteristics.update(characteristics)
self._pending_subscribes.update(characteristics)
# Try to subscribe to the characteristics all at once
if not self._subscribe_timer:
self._subscribe_timer = async_call_later(
self.hass,
SUBSCRIBE_COOLDOWN,
self._async_subscribe,
)

def remove_watchable_characteristics(self, accessory_id: int) -> None:
@callback
def _async_cancel_subscription_timer(self) -> None:
"""Cancel the subscribe timer."""
if self._subscribe_timer:
self._subscribe_timer()
self._subscribe_timer = None

async def _async_subscribe(self, _now: datetime) -> None:
"""Subscribe to characteristics."""
self._subscribe_timer = None
if self._pending_subscribes:
subscribes = self._pending_subscribes.copy()
self._pending_subscribes.clear()
await self.pairing.subscribe(subscribes)

def remove_watchable_characteristics(
self, characteristics: list[tuple[int, int]]
) -> None:
"""Remove all pollable characteristics by accessory id."""
self.watchable_characteristics = [
char for char in self.watchable_characteristics if char[0] != accessory_id
]
for aid_iid in characteristics:
self.watchable_characteristics.discard(aid_iid)
self._pending_subscribes.discard(aid_iid)

@callback
def async_set_available_state(self, available: bool) -> None:
Expand Down Expand Up @@ -264,6 +292,7 @@ async def async_setup(self) -> None:
entry.async_on_unload(
pairing.dispatcher_availability_changed(self.async_set_available_state)
)
entry.async_on_unload(self._async_cancel_subscription_timer)

await self.async_process_entity_map()

Expand Down Expand Up @@ -605,8 +634,6 @@ def process_config_changed(self, config_num: int) -> None:
async def async_update_new_accessories_state(self) -> None:
"""Process a change in the pairings accessories state."""
await self.async_process_entity_map()
if self.watchable_characteristics:
await self.pairing.subscribe(self.watchable_characteristics)
for callback_ in self._config_changed_callbacks:
callback_()
await self.async_update()
Expand All @@ -623,7 +650,7 @@ def _add_new_entities_for_accessory(self, handlers) -> None:
if (accessory.aid, None, None) in self.entities:
continue
if handler(accessory):
self.entities.append((accessory.aid, None, None))
self.entities.add((accessory.aid, None, None))
break

def add_char_factory(self, add_entities_cb: AddCharacteristicCb) -> None:
Expand All @@ -639,7 +666,7 @@ def _add_new_entities_for_char(self, handlers) -> None:
if (accessory.aid, service.iid, char.iid) in self.entities:
continue
if handler(char):
self.entities.append((accessory.aid, service.iid, char.iid))
self.entities.add((accessory.aid, service.iid, char.iid))
break

def add_listener(self, add_entities_cb: AddServiceCb) -> None:
Expand Down Expand Up @@ -687,7 +714,7 @@ def _add_new_entities(self, callbacks) -> None:

for listener in callbacks:
if listener(service):
self.entities.append((aid, None, iid))
self.entities.add((aid, None, iid))
break

async def async_load_platform(self, platform: str) -> None:
Expand Down Expand Up @@ -811,7 +838,7 @@ def process_new_events(

@callback
def _remove_characteristics_callback(
self, characteristics: Iterable[tuple[int, int]], callback_: CALLBACK_TYPE
self, characteristics: set[tuple[int, int]], callback_: CALLBACK_TYPE
) -> None:
"""Remove a characteristics callback."""
for aid_iid in characteristics:
Expand All @@ -821,7 +848,7 @@ def _remove_characteristics_callback(

@callback
def async_subscribe(
self, characteristics: Iterable[tuple[int, int]], callback_: CALLBACK_TYPE
self, characteristics: set[tuple[int, int]], callback_: CALLBACK_TYPE
) -> CALLBACK_TYPE:
"""Add characteristics to the watch list."""
for aid_iid in characteristics:
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/homekit_controller/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,5 @@
# also happens to be the same value used by
# the update coordinator.
DEBOUNCE_COOLDOWN = 10 # seconds

SUBSCRIBE_COOLDOWN = 0.25 # seconds
7 changes: 4 additions & 3 deletions homeassistant/components/homekit_controller/device_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def __init__(self, hass: HomeAssistant) -> None:
self._callbacks: dict[tuple[str, str], list[Callable[[Any], None]]] = {}
self._iid_trigger_keys: dict[int, set[tuple[str, str]]] = {}

async def async_setup(
@callback
def async_setup(
self, connection: HKDevice, aid: int, triggers: list[dict[str, Any]]
) -> None:
"""Set up a set of triggers for a device.
Expand All @@ -78,7 +79,7 @@ async def async_setup(
self._triggers[trigger_key] = trigger_data
iid = trigger_data["characteristic"]
self._iid_trigger_keys.setdefault(iid, set()).add(trigger_key)
await connection.add_watchable_characteristics([(aid, iid)])
connection.add_watchable_characteristics([(aid, iid)])

def fire(self, iid: int, ev: dict[str, Any]) -> None:
"""Process events that have been received from a HomeKit accessory."""
Expand Down Expand Up @@ -237,7 +238,7 @@ def async_add_characteristic(service: Service):
return False

trigger = async_get_or_create_trigger_source(conn.hass, device_id)
hass.async_create_task(trigger.async_setup(conn, aid, triggers))
trigger.async_setup(conn, aid, triggers)

return True

Expand Down
Loading

0 comments on commit c3d1db5

Please sign in to comment.