Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 83 additions & 27 deletions hass_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import logging
import os
import pprint
import re
from asyncio import CancelledError
from collections.abc import Callable
from types import TracebackType
from typing import Any

import aiohttp
from aiohttp import (
ClientSession,
ClientWebSocketResponse,
Expand All @@ -28,7 +31,7 @@
ConnectionFailed,
FailedCommand,
InvalidMessage,
NotConnected,
NotConnected, ConnectionFailedDueToLargeMessage,
)
from .models import (
Area,
Expand Down Expand Up @@ -61,7 +64,6 @@
EntityChangedCallback = Callable[[EntityStateEvent], None]
SubscriptionCallback = Callable[[Message], None]


class HomeAssistantClient:
"""Connection to HomeAssistant (over websockets)."""

Expand Down Expand Up @@ -93,6 +95,12 @@ def __init__(
self._shutdown_complete_event: asyncio.Event | None = None
self._msg_id_lock = asyncio.Lock()

# Keep track of the maximum message size
self._max_msg_size = 4 * 1024 * 1024

# Event object for efficient reconnection waiting
self._connected_event = asyncio.Event()

@property
def connected(self) -> bool:
"""Return if we're currently connected."""
Expand All @@ -103,6 +111,9 @@ def version(self) -> str:
"""Return version of connected Home Assistant instance."""
return self._version

async def wait_for_connection(self) -> None:
    """Wait until the client is connected.

    Awaits the internal ``_connected_event``, which ``connect()`` sets once
    the message-processing task has been started and ``_close_client()``
    clears again. Returns immediately if the event is already set.
    """
    await self._connected_event.wait()

async def subscribe_events(
self, cb_func: Callable[[Event], None], event_type: str = MATCH_ALL
) -> Callable:
Expand Down Expand Up @@ -170,35 +181,44 @@ async def call_service(
msg["service_data"] = service_data
if target:
msg["target"] = target
return await self.send_command(msg)
return await self.send_retryable_command(msg)

async def get_states(self) -> list[State]:
"""Get dump of the current states within Home Assistant."""
return await self.send_command("get_states")
return await self.send_retryable_command("get_states")

async def get_config(self) -> list[Config]:
"""Get dump of the current config in Home Assistant."""
return await self.send_command("get_states")
return await self.send_retryable_command("get_config")

async def get_services(self) -> dict[str, dict[str, Any]]:
"""Get dump of the current services in Home Assistant."""
return await self.send_command("get_services")
return await self.send_retryable_command("get_services")

async def get_area_registry(self) -> list[Area]:
"""Get Area Registry."""
return await self.send_command("config/area_registry/list")
return await self.send_retryable_command("config/area_registry/list")

async def get_device_registry(self) -> list[Device]:
"""Get Device Registry."""
return await self.send_command("config/device_registry/list")
return await self.send_retryable_command("config/device_registry/list")

async def get_entity_registry(self) -> list[Entity]:
"""Get Entity Registry."""
return await self.send_command("config/entity_registry/list")
return await self.send_retryable_command("config/entity_registry/list")

async def get_entity_registry_entry(self, entity_id: str) -> Entity:
"""Get single entry from Entity Registry."""
return await self.send_command("config/entity_registry/get", entity_id=entity_id)
return await self.send_retryable_command("config/entity_registry/get", entity_id=entity_id)

async def send_retryable_command(self, command: str, **kwargs: dict[str, Any]) -> CommandResultData:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I'm a fan of this pattern. Why not introduce an argument to specify whether the command may be retried?
Also, I think it would be good to specify a number of retries. For example, introduce an argument `retries: int = 3`, so that by default a command is retried 3 times on common connection errors.

"""Send a command to the HA websocket and return response. Retry on failure."""
while True:
try:
return await self.send_command(command, **kwargs)
except ConnectionFailedDueToLargeMessage:
LOGGER.debug("Connection failed due to large message - waiting for reconnect and then retrying")
await self.wait_for_connection()

async def send_command(self, command: str, **kwargs: dict[str, Any]) -> CommandResultData:
"""Send a command to the HA websocket and return response."""
Expand All @@ -212,6 +232,10 @@ async def send_command(self, command: str, **kwargs: dict[str, Any]) -> CommandR
await self._send_json_message(message)
try:
return await future
except CancelledError as e:
if len(e.args) > 0:
# Raise the inner exception
raise e.args[0] from e
Comment on lines +236 to +238
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this?

It is not very common to catch `CancelledError`. Why are you doing it this way?

Copy link
Copy Markdown
Author

@wrouesnel wrouesnel Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The websocket connection shuts down when a "message too big" error is returned, which means we experience the failure as a cancellation. But we need to know what actually caused it to shut down. That cause is packaged as the first element of `e.args`, since from `send_command` we'd like to get the actual exception, not just "everything was aborted".

The relevant usage is here:

    async def send_retryable_command(self, command: str, **kwargs: dict[str, Any]) -> CommandResultData:
        """Send a command to the HA websocket and return response. Retry on failure."""
        while True:
            try:
                return await self.send_command(command, **kwargs)
            except ConnectionFailedDueToLargeMessage:
                LOGGER.debug("Connection failed due to large message - waiting for reconnect and then retrying")
                await self.wait_for_connection()

where we want to either keep raising the exception (because it's not something known harmless), or handle it because we actually know what it is (ConnectionFailedDueToLargeMessage). Otherwise calling code doesn't know why the connection was ended.

Of note: this catch could handle other exceptions too, but I don't know what they should be at the moment (a reasonable one would be something like "ConnectionLostRetriesExceeded").

finally:
self._result_futures.pop(message_id)

Expand Down Expand Up @@ -263,7 +287,8 @@ async def connect(self) -> None:
ws_token = self._token or os.environ.get("HASSIO_TOKEN")
LOGGER.debug("Connecting to Home Assistant Websocket API on %s", ws_url)
try:
self._client = await self._http_session.ws_connect(ws_url, heartbeat=55)
self._client = await self._http_session.ws_connect(ws_url, heartbeat=55,
max_msg_size=self._max_msg_size)
version_msg: AuthRequiredMessage = await self._client.receive_json()
self._version = version_msg["ha_version"]
# send authentication
Expand All @@ -285,6 +310,16 @@ async def connect(self) -> None:
)
# start task to handle incoming messages
self._loop.create_task(self._process_messages())
# notify watchers we're connected
self._connected_event.set()

async def _close_client(self) -> None:
    """Clear the connected state and close the underlying websocket client.

    The websocket is only closed if it is not already closed, so this is
    safe to call both from ``disconnect()`` and from the cleanup path of
    the message-processing loop.
    """
    # Clear the event before closing so that wait_for_connection() callers
    # block until connect() sets it again - close has been called.
    self._connected_event.clear()
    # Skip the close call when the websocket already reports itself closed.
    if not self._client.closed:
        await self._client.close()


async def disconnect(self) -> None:
"""Disconnect the client."""
Expand All @@ -294,11 +329,12 @@ async def disconnect(self) -> None:
return

self._shutdown_complete_event = asyncio.Event()
await self._client.close()
await self._close_client()
await self._shutdown_complete_event.wait()

async def _process_messages(self) -> None:
"""Start listening to the websocket."""
terminating_exception = None
try:
while not self._client.closed:
msg = await self._client.receive()
Expand All @@ -307,6 +343,18 @@ async def _process_messages(self) -> None:
break

if msg.type == WSMsgType.ERROR:
# Home Assistant can produce some *very* large messages, and there's
# no sign of a chunking API turning up soon. So check if we're losing
# the connection due to message size.
if msg.data.code == aiohttp.WSCloseCode.MESSAGE_TOO_BIG:
# Parse the attempted size out, and schedule a reconnect.
if (m := re.match(r"Message size (\d+)", msg.data.args[1])) is not None:
attempted_message_size = int(m.group(1))
# Set to 2x what they attempted to send us so hopefully we'll succeed
# on reconnect.
self._max_msg_size = attempted_message_size * 2
raise ConnectionFailedDueToLargeMessage()

Comment on lines +349 to +357
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you have a very specific use case here; it is not very common to have THIS MANY entities.
So my suggestion would be to just increase the default message limit for now and keep it simple.

Copy link
Copy Markdown
Author

@wrouesnel wrouesnel Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The theory is that the user is already aware, or doesn't care - which for me is true because my HA installation operates just fine (hence how I discovered this issue - turning on iBeacons probably did it and I've not cleaned up yet).

So the logic is that if HA doesn't hit this (and it doesn't seem to have any limit other than browser memory), then the user probably doesn't care either for any practical case, because this is HA-specific.

Basically: if the normal HA client isn't breaking (which is much heavier on memory) then the Python client shouldn't break either - it would Just Work(TM).

The other side of that is an obvious reason to use this library would be to do something like scripting a bulk clean up entities as well.

EDIT: Basically I'd argue that, usability-wise, the constraint goes the other way: if reconnects are automatic, then updating receive size limits should be automatic too, unless the user specifically has a reason to constrain them.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, although I'm a bit worried about the readability/maintainability of this client with this added overhead.
If you adjust the PR so it's compatible with the latest refactor, I'm fine with merging it.

Another, simpler approach could be to just have a config param to set the default message size for more advanced setups.

raise ConnectionFailed()

if msg.type != WSMsgType.TEXT:
Expand All @@ -321,20 +369,26 @@ async def _process_messages(self) -> None:
LOGGER.debug("Received message:\n%s\n", pprint.pformat(msg))

self._handle_incoming_message(data)
except Exception as e:
terminating_exception = e

finally:
LOGGER.debug("Listen completed. Cleaning up")
if terminating_exception is not None:
LOGGER.debug("Listen finished with exception - cancelling futures with exception name: %s",
type(terminating_exception).__name__)
else:
LOGGER.debug("Listen completing normally")

for future in self._result_futures.values():
future.cancel()
LOGGER.debug("Listen completed. Cleaning up")
await self._close_client()

if not self._client.closed:
await self._client.close()
LOGGER.debug("Connection closed - cancelling futures")
for future in self._result_futures.values():
future.cancel(msg=terminating_exception)

if self._shutdown_complete_event:
self._shutdown_complete_event.set()
else:
self._on_connection_lost()
if self._shutdown_complete_event:
self._shutdown_complete_event.set()
else:
self._on_connection_lost()

def _handle_incoming_message(self, msg: Message) -> None:
"""Handle incoming message."""
Expand Down Expand Up @@ -407,12 +461,8 @@ async def auto_reconnect():
attempts = 0
sleep_time = 2
while True:
# Try to reconnect right away in case this is recoverable immediately...
attempts += 1
if attempts > 20:
sleep_time = 60
elif sleep_time > 10:
sleep_time = 10
await asyncio.sleep(sleep_time)
try:
await self.connect()
# resubscribe all subscriptions
Expand All @@ -424,6 +474,12 @@ async def auto_reconnect():
return
except CannotConnect:
pass
# Failed, go to sleep now...
if attempts > 20:
sleep_time = 60
elif sleep_time > 10:
sleep_time = 10
await asyncio.sleep(sleep_time)
if attempts >= 30:
LOGGER.warning(
"Still could not reconnect after %s attempts, is the server alive ?",
Expand Down
2 changes: 2 additions & 0 deletions hass_client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(self, error: Exception | None = None) -> None:
return
super().__init__(f"{error}", error)

class ConnectionFailedDueToLargeMessage(ConnectionFailed):
    """Exception raised when an established connection fails due to an oversize message.

    Subclass of ``ConnectionFailed``, so existing handlers for that exception
    still catch it; callers that want to treat the oversize-message case
    differently (for example by retrying after a reconnect) can catch this
    type specifically.
    """

class NotFoundError(BaseHassClientError):
"""Exception that is raised when an entity can't be found."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools~=62.3", "wheel~=0.37.1"]
requires = ["setuptools~=64.0", "wheel~=0.37.1"]
build-backend = "setuptools.build_meta"

[project]
Expand Down