Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up ZHA initialization and improve startup responsiveness #108103

Merged
merged 6 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
ZHA_CLUSTER_HANDLER_MSG_DATA,
ZHA_CLUSTER_HANDLER_READS_PER_REQ,
)
from ..helpers import LogMixin, retryable_req, safe_read
from ..helpers import LogMixin, safe_read

if TYPE_CHECKING:
from ..endpoint import Endpoint
Expand Down Expand Up @@ -362,7 +362,6 @@ async def async_configure(self) -> None:
self.debug("skipping cluster handler configuration")
self._status = ClusterHandlerStatus.CONFIGURED

@retryable_req(delays=(1, 1, 3))
async def async_initialize(self, from_cache: bool) -> None:
"""Initialize cluster handler."""
if not from_cache and self._endpoint.device.skip_configuration:
Expand Down
17 changes: 11 additions & 6 deletions homeassistant/components/zha/core/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,17 @@ async def async_initialize(self, from_cache: bool = False) -> None:
self.debug("started initialization")
await self._zdo_handler.async_initialize(from_cache)
self._zdo_handler.debug("'async_initialize' stage succeeded")
await asyncio.gather(
*(
endpoint.async_initialize(from_cache)
for endpoint in self._endpoints.values()
)
)

puddly marked this conversation as resolved.
Show resolved Hide resolved
# We intentionally do not use `gather` here! This is so that if, for example,
# three `device.async_initialize()`s are spawned, only three concurrent requests
# will ever be in flight at once. Startup concurrency is managed at the device
# level.
for endpoint in self._endpoints.values():
try:
await endpoint.async_initialize(from_cache)
except Exception: # pylint: disable=broad-exception-caught
self.debug("Failed to initialize endpoint", exc_info=True)

self.debug("power source: %s", self.power_source)
self.status = DeviceStatus.INITIALIZED
self.debug("completed initialization")
Expand Down
22 changes: 18 additions & 4 deletions homeassistant/components/zha/core/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
import functools
import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar

Expand All @@ -11,6 +12,7 @@
from homeassistant.const import Platform
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.util.async_ import gather_with_limited_concurrency

from . import const, discovery, registries
from .cluster_handlers import ClusterHandler
Expand Down Expand Up @@ -169,20 +171,32 @@ def add_client_cluster_handlers(self) -> None:

async def async_initialize(self, from_cache: bool = False) -> None:
"""Initialize claimed cluster handlers."""
await self._execute_handler_tasks("async_initialize", from_cache)
await self._execute_handler_tasks(
"async_initialize", from_cache, max_concurrency=1
)

async def async_configure(self) -> None:
    """Configure claimed cluster handlers.

    Delegates to ``_execute_handler_tasks`` with the method name
    ``"async_configure"`` and no extra arguments.
    """
    await self._execute_handler_tasks("async_configure")

async def _execute_handler_tasks(self, func_name: str, *args: Any) -> None:
async def _execute_handler_tasks(
self, func_name: str, *args: Any, max_concurrency: int | None = None
) -> None:
"""Add a throttled cluster handler task and swallow exceptions."""
cluster_handlers = [
*self.claimed_cluster_handlers.values(),
*self.client_cluster_handlers.values(),
]
tasks = [getattr(ch, func_name)(*args) for ch in cluster_handlers]
results = await asyncio.gather(*tasks, return_exceptions=True)

gather: Callable[..., Awaitable]

if max_concurrency is None:
gather = asyncio.gather
else:
gather = functools.partial(gather_with_limited_concurrency, max_concurrency)

results = await gather(*tasks, return_exceptions=True)
for cluster_handler, outcome in zip(cluster_handlers, results):
if isinstance(outcome, Exception):
cluster_handler.warning(
Expand Down
49 changes: 37 additions & 12 deletions homeassistant/components/zha/core/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import re
import time
from typing import TYPE_CHECKING, Any, NamedTuple, Self
from typing import TYPE_CHECKING, Any, NamedTuple, Self, cast

from zigpy.application import ControllerApplication
from zigpy.config import (
Expand All @@ -36,6 +36,7 @@
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.async_ import gather_with_limited_concurrency

from . import discovery
from .const import (
Expand Down Expand Up @@ -292,6 +293,39 @@ def async_load_groups(self) -> None:
# entity registry tied to the devices
discovery.GROUP_PROBE.discover_group_entities(zha_group)

@property
def radio_concurrency(self) -> int:
    """Return the ``max_value`` of the controller's request semaphore."""
    # The semaphore is a private attribute of the application controller,
    # hence the protected-access suppression on the line that reads it.
    semaphore = self.application_controller._concurrent_requests_semaphore  # pylint: disable=protected-access
    return semaphore.max_value

async def async_fetch_updated_state_mains(self) -> None:
    """Re-initialize online mains powered devices, bypassing the cache.

    A device is polled only when it is mains powered, has a ``last_seen``
    timestamp, and that timestamp is more recent than the device's
    ``consider_unavailable_time``. The most recently seen devices are
    initialized first, and the number of concurrent initializations is
    capped below the radio concurrency.
    """
    _LOGGER.debug("Fetching current state for mains powered devices")

    timestamp = time.time()

    # Only delay startup for mains powered devices that are currently online
    candidates = []
    for device in self.devices.values():
        if not device.is_mains_powered:
            continue
        if device.last_seen is None:
            continue
        if not (timestamp - device.last_seen < device.consider_unavailable_time):
            continue
        candidates.append(device)

    # Most recently contacted devices go first
    by_recency = sorted(
        candidates,
        key=lambda device: cast(float, device.last_seen),
        reverse=True,
    )

    # Keep four radio slots free for non-startup requests, but never drop
    # below a single concurrent poll
    limit = max(1, self.radio_concurrency - 4)

    refreshes = [device.async_initialize(from_cache=False) for device in by_recency]
    await gather_with_limited_concurrency(limit, *refreshes)

    _LOGGER.debug("completed fetching current state for mains powered devices")

async def async_initialize_devices_and_entities(self) -> None:
"""Initialize devices and load entities."""

Expand All @@ -302,17 +336,8 @@ async def async_initialize_devices_and_entities(self) -> None:

async def fetch_updated_state() -> None:
"""Fetch updated state for mains powered devices."""
_LOGGER.debug("Fetching current state for mains powered devices")
await asyncio.gather(
*(
dev.async_initialize(from_cache=False)
for dev in self.devices.values()
if dev.is_mains_powered
)
)
_LOGGER.debug(
"completed fetching current state for mains powered devices - allowing polled requests"
)
await self.async_fetch_updated_state_mains()
_LOGGER.debug("Allowing polled requests")
self.hass.data[DATA_ZHA].allow_polling = True

# background the fetching of state for mains powered devices
Expand Down
60 changes: 2 additions & 58 deletions homeassistant/components/zha/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,15 @@
"""
from __future__ import annotations

import asyncio
import binascii
import collections
from collections.abc import Callable, Collection, Coroutine, Iterator
from collections.abc import Callable, Iterator
import dataclasses
from dataclasses import dataclass
import enum
import functools
import itertools
import logging
from random import uniform
import re
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

import voluptuous as vol
import zigpy.exceptions
Expand Down Expand Up @@ -322,58 +318,6 @@ def error(self, msg, *args, **kwargs):
return self.log(logging.ERROR, msg, *args, **kwargs)


def retryable_req(
    delays: Collection[float] = (1, 5, 10, 15, 30, 60, 120, 180, 360, 600, 900, 1800),
    raise_: bool = False,
) -> Callable[
    [Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]]],
    Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]],
]:
    """Make a method with ZCL requests retryable.

    The wrapped coroutine is attempted once and then retried once per entry
    in ``delays``, so it runs at most ``len(delays) + 1`` times. Before each
    retry the wrapper sleeps for the matching delay, randomized to between
    75% and 125% of its value. A delay of ``0`` is a valid entry and simply
    retries without waiting.

    Only ``zigpy.exceptions.ZigbeeException`` and ``asyncio.TimeoutError``
    trigger a retry; any other exception propagates immediately.

    delays: delays, in seconds, to wait before each retry.
    raise_: if True, the exception of the final attempt is re-raised;
        otherwise the wrapper logs a warning and returns ``None``.
    """

    def decorator(
        func: Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R]],
    ) -> Callable[Concatenate[_ClusterHandlerT, _P], Coroutine[Any, Any, _R | None]]:
        @functools.wraps(func)
        async def wrapper(
            cluster_handler: _ClusterHandlerT, *args: _P.args, **kwargs: _P.kwargs
        ) -> _R | None:
            exceptions = (zigpy.exceptions.ZigbeeException, asyncio.TimeoutError)
            try_count, errors = 1, []
            # The trailing ``None`` marks the last attempt: there is no delay
            # left to wait for, so a failure there is final.
            for delay in itertools.chain(delays, [None]):
                try:
                    return await func(cluster_handler, *args, **kwargs)
                except exceptions as ex:
                    errors.append(ex)
                    # Compare against the sentinel explicitly so that a
                    # configured delay of 0 is not mistaken for "no retries
                    # left".
                    if delay is not None:
                        delay = uniform(delay * 0.75, delay * 1.25)
                        cluster_handler.debug(
                            "%s: retryable request #%d failed: %s. Retrying in %ss",
                            func.__name__,
                            try_count,
                            ex,
                            round(delay, 1),
                        )
                        try_count += 1
                        await asyncio.sleep(delay)
                    else:
                        cluster_handler.warning(
                            "%s: all attempts have failed: %s", func.__name__, errors
                        )
                        if raise_:
                            raise
            return None

        return wrapper

    return decorator


def convert_install_code(value: str) -> bytes:
"""Convert string to install code bytes and validate length."""

Expand Down
2 changes: 1 addition & 1 deletion tests/components/zha/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _wrap_mock_instance(obj: Any) -> MagicMock:
real_attr = getattr(obj, attr_name)
mock_attr = getattr(mock, attr_name)

if callable(real_attr):
if callable(real_attr) and not hasattr(real_attr, "__aenter__"):
mock_attr.side_effect = real_attr
else:
setattr(mock, attr_name, real_attr)
Expand Down
82 changes: 81 additions & 1 deletion tests/components/zha/test_gateway.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Test ZHA Gateway."""
import asyncio
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from zigpy.application import ControllerApplication
import zigpy.profiles.zha as zha
import zigpy.types
import zigpy.zcl.clusters.general as general
import zigpy.zcl.clusters.lighting as lighting
import zigpy.zdo.types

from homeassistant.components.zha.core.gateway import ZHAGateway
from homeassistant.components.zha.core.group import GroupMember
Expand Down Expand Up @@ -321,3 +323,81 @@ async def test_single_reload_on_multiple_connection_loss(
assert len(mock_reload.mock_calls) == 1

await hass.async_block_till_done()


@pytest.mark.parametrize("radio_concurrency", [1, 2, 8])
async def test_startup_concurrency_limit(
    radio_concurrency: int,
    hass: HomeAssistant,
    zigpy_app_controller: ControllerApplication,
    config_entry: MockConfigEntry,
    zigpy_device_mock,
) -> None:
    """Test ZHA gateway limits concurrency on startup."""
    config_entry.add_to_hass(hass)
    zha_gateway = ZHAGateway(hass, {}, config_entry)

    with patch(
        "bellows.zigbee.application.ControllerApplication.new",
        return_value=zigpy_app_controller,
    ):
        await zha_gateway.async_initialize()

    # Create many mains powered lights so that startup polling has far more
    # devices to initialize than any tested concurrency limit
    for i in range(50):
        zigpy_dev = zigpy_device_mock(
            {
                1: {
                    SIG_EP_INPUT: [
                        general.OnOff.cluster_id,
                        general.LevelControl.cluster_id,
                        lighting.Color.cluster_id,
                        general.Groups.cluster_id,
                    ],
                    SIG_EP_OUTPUT: [],
                    SIG_EP_TYPE: zha.DeviceType.COLOR_DIMMABLE_LIGHT,
                    SIG_EP_PROFILE: zha.PROFILE_ID,
                }
            },
            ieee=f"11:22:33:44:{i:08x}",
            nwk=0x1234 + i,
        )
        zigpy_dev.node_desc.mac_capability_flags |= (
            zigpy.zdo.types.NodeDescriptor.MACCapabilityFlags.MainsPowered
        )

        zha_gateway._async_get_or_create_device(zigpy_dev, restored=True)

    # Keep track of request concurrency during initialization
    current_concurrency = 0
    concurrencies = []

    async def mock_send_packet(*args, **kwargs):
        nonlocal current_concurrency

        current_concurrency += 1
        concurrencies.append(current_concurrency)

        await asyncio.sleep(0.001)

        current_concurrency -= 1
        concurrencies.append(current_concurrency)

    # Patch the property through a context manager so the real
    # `ZHAGateway.radio_concurrency` is restored afterwards; assigning a
    # PropertyMock straight onto the class would leak into every later test.
    with patch.object(
        ZHAGateway,
        "radio_concurrency",
        new_callable=PropertyMock,
        return_value=radio_concurrency,
    ):
        assert zha_gateway.radio_concurrency == radio_concurrency

        with patch(
            "homeassistant.components.zha.core.device.ZHADevice.async_initialize",
            side_effect=mock_send_packet,
        ):
            await zha_gateway.async_fetch_updated_state_mains()

    await zha_gateway.shutdown()

    # Make sure concurrency was always limited
    assert current_concurrency == 0
    assert min(concurrencies) == 0

    if radio_concurrency > 1:
        assert 1 <= max(concurrencies) < radio_concurrency
    else:
        assert max(concurrencies) == radio_concurrency == 1
Loading