-
Notifications
You must be signed in to change notification settings - Fork 15
Fix handling of "message too large" with autoscaling #143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,10 +9,13 @@ | |
| import logging | ||
| import os | ||
| import pprint | ||
| import re | ||
| from asyncio import CancelledError | ||
| from collections.abc import Callable | ||
| from types import TracebackType | ||
| from typing import Any | ||
|
|
||
| import aiohttp | ||
| from aiohttp import ( | ||
| ClientSession, | ||
| ClientWebSocketResponse, | ||
|
|
@@ -28,7 +31,7 @@ | |
| ConnectionFailed, | ||
| FailedCommand, | ||
| InvalidMessage, | ||
| NotConnected, | ||
| NotConnected, ConnectionFailedDueToLargeMessage, | ||
| ) | ||
| from .models import ( | ||
| Area, | ||
|
|
@@ -61,7 +64,6 @@ | |
| EntityChangedCallback = Callable[[EntityStateEvent], None] | ||
| SubscriptionCallback = Callable[[Message], None] | ||
|
|
||
|
|
||
| class HomeAssistantClient: | ||
| """Connection to HomeAssistant (over websockets).""" | ||
|
|
||
|
|
@@ -93,6 +95,12 @@ def __init__( | |
| self._shutdown_complete_event: asyncio.Event | None = None | ||
| self._msg_id_lock = asyncio.Lock() | ||
|
|
||
| # Keep track of the maximum message size | ||
| self._max_msg_size = 4 * 1024 * 1024 | ||
|
|
||
| # Event object for efficient reconnection waiting | ||
| self._connected_event = asyncio.Event() | ||
|
|
||
| @property | ||
| def connected(self) -> bool: | ||
| """Return if we're currently connected.""" | ||
|
|
@@ -103,6 +111,9 @@ def version(self) -> str: | |
| """Return version of connected Home Assistant instance.""" | ||
| return self._version | ||
|
|
||
| async def wait_for_connection(self): | ||
| await self._connected_event.wait() | ||
|
|
||
| async def subscribe_events( | ||
| self, cb_func: Callable[[Event], None], event_type: str = MATCH_ALL | ||
| ) -> Callable: | ||
|
|
@@ -170,35 +181,44 @@ async def call_service( | |
| msg["service_data"] = service_data | ||
| if target: | ||
| msg["target"] = target | ||
| return await self.send_command(msg) | ||
| return await self.send_retryable_command(msg) | ||
|
|
||
| async def get_states(self) -> list[State]: | ||
| """Get dump of the current states within Home Assistant.""" | ||
| return await self.send_command("get_states") | ||
| return await self.send_retryable_command("get_states") | ||
|
|
||
| async def get_config(self) -> list[Config]: | ||
| """Get dump of the current config in Home Assistant.""" | ||
| return await self.send_command("get_states") | ||
| return await self.send_retryable_command("get_config") | ||
|
|
||
| async def get_services(self) -> dict[str, dict[str, Any]]: | ||
| """Get dump of the current services in Home Assistant.""" | ||
| return await self.send_command("get_services") | ||
| return await self.send_retryable_command("get_services") | ||
|
|
||
| async def get_area_registry(self) -> list[Area]: | ||
| """Get Area Registry.""" | ||
| return await self.send_command("config/area_registry/list") | ||
| return await self.send_retryable_command("config/area_registry/list") | ||
|
|
||
| async def get_device_registry(self) -> list[Device]: | ||
| """Get Device Registry.""" | ||
| return await self.send_command("config/device_registry/list") | ||
| return await self.send_retryable_command("config/device_registry/list") | ||
|
|
||
| async def get_entity_registry(self) -> list[Entity]: | ||
| """Get Entity Registry.""" | ||
| return await self.send_command("config/entity_registry/list") | ||
| return await self.send_retryable_command("config/entity_registry/list") | ||
|
|
||
| async def get_entity_registry_entry(self, entity_id: str) -> Entity: | ||
| """Get single entry from Entity Registry.""" | ||
| return await self.send_command("config/entity_registry/get", entity_id=entity_id) | ||
| return await self.send_retryable_command("config/entity_registry/get", entity_id=entity_id) | ||
|
|
||
| async def send_retryable_command(self, command: str, **kwargs: dict[str, Any]) -> CommandResultData: | ||
| """Send a command to the HA websocket and return response. Retry on failure.""" | ||
| while True: | ||
| try: | ||
| return await self.send_command(command, **kwargs) | ||
| except ConnectionFailedDueToLargeMessage: | ||
| LOGGER.debug("Connection failed due to large message - waiting for reconnect and then retrying") | ||
| await self.wait_for_connection() | ||
|
|
||
| async def send_command(self, command: str, **kwargs: dict[str, Any]) -> CommandResultData: | ||
| """Send a command to the HA websocket and return response.""" | ||
|
|
@@ -212,6 +232,10 @@ async def send_command(self, command: str, **kwargs: dict[str, Any]) -> CommandR | |
| await self._send_json_message(message) | ||
| try: | ||
| return await future | ||
| except CancelledError as e: | ||
| if len(e.args) > 0: | ||
| # Raise the inner exception | ||
| raise e.args[0] from e | ||
|
Comment on lines
+236
to
+238
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is this ? It is not very common to catch the CancelledError, why are you doing this like this ?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The websocket connection shuts down when a TooLarge error message is returned, which means we experience the failure as a cancellation. But we need to know what actually caused it to shutdown, which is packaged as the first parameter of The relevant usage is here: async def send_retryable_command(self, command: str, **kwargs: dict[str, Any]) -> CommandResultData:
"""Send a command to the HA websocket and return response. Retry on failure."""
while True:
try:
return await self.send_command(command, **kwargs)
except ConnectionFailedDueToLargeMessage:
LOGGER.debug("Connection failed due to large message - waiting for reconnect and then retrying")
await self.wait_for_connection()where we want to either keep raising the exception (because it's not something known harmless), or handle it because we actually know what it is (ConnectionFailedDueToLargeMessage). Otherwise calling code doesn't know why the connection was ended. Of note: this catch could be other things, but I don't know what they should be at the moment (a reasonable one would be something like "ConnectionLostRetriesExceeded" or something.) |
||
| finally: | ||
| self._result_futures.pop(message_id) | ||
|
|
||
|
|
@@ -263,7 +287,8 @@ async def connect(self) -> None: | |
| ws_token = self._token or os.environ.get("HASSIO_TOKEN") | ||
| LOGGER.debug("Connecting to Home Assistant Websocket API on %s", ws_url) | ||
| try: | ||
| self._client = await self._http_session.ws_connect(ws_url, heartbeat=55) | ||
| self._client = await self._http_session.ws_connect(ws_url, heartbeat=55, | ||
| max_msg_size=self._max_msg_size) | ||
| version_msg: AuthRequiredMessage = await self._client.receive_json() | ||
| self._version = version_msg["ha_version"] | ||
| # send authentication | ||
|
|
@@ -285,6 +310,16 @@ async def connect(self) -> None: | |
| ) | ||
| # start task to handle incoming messages | ||
| self._loop.create_task(self._process_messages()) | ||
| # notify watchers we're connected | ||
| self._connected_event.set() | ||
|
|
||
| async def _close_client(self) -> None: | ||
| """Invoke the underlying client close operation and clear the connected state""" | ||
| # Block any new users from sending messages - close has been called | ||
| self._connected_event.clear() | ||
| if not self._client.closed: | ||
| await self._client.close() | ||
|
|
||
|
|
||
| async def disconnect(self) -> None: | ||
| """Disconnect the client.""" | ||
|
|
@@ -294,11 +329,12 @@ async def disconnect(self) -> None: | |
| return | ||
|
|
||
| self._shutdown_complete_event = asyncio.Event() | ||
| await self._client.close() | ||
| await self._close_client() | ||
| await self._shutdown_complete_event.wait() | ||
|
|
||
| async def _process_messages(self) -> None: | ||
| """Start listening to the websocket.""" | ||
| terminating_exception = None | ||
| try: | ||
| while not self._client.closed: | ||
| msg = await self._client.receive() | ||
|
|
@@ -307,6 +343,18 @@ async def _process_messages(self) -> None: | |
| break | ||
|
|
||
| if msg.type == WSMsgType.ERROR: | ||
| # Home Assistant can produce some *very* large messages, and there's | ||
| # no sign of a chunking API turning up soon. So check if we're losing | ||
| # the connection due to message size. | ||
| if msg.data.code == aiohttp.WSCloseCode.MESSAGE_TOO_BIG: | ||
| # Parse the attempted size out, and schedule a reconnect. | ||
| if (m := re.match(r"Message size (\d+)", msg.data.args[1])) is not None: | ||
| attempted_message_size = int(m.group(1)) | ||
| # Set to 2x what they attempted to send us so hopefully we'll succeed | ||
| # on reconnect. | ||
| self._max_msg_size = attempted_message_size * 2 | ||
| raise ConnectionFailedDueToLargeMessage() | ||
|
|
||
|
Comment on lines
+349
to
+357
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you have a very specific usecase here, it is not very common to have THIS MANY entities.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The theory is that the user is already aware, or doesn't care - which for me is true because my HA installation operates just fine (hence how I discovered this issue - turning on iBeacons probably did it and I've not cleaned up yet). So the logic is that if HA doesn't hit this (and it doesn't seem to have any limit other then browser memory) then the user probably doesn't care either for any practical case because this is HA-specific. Basically: if the normal HA client isn't breaking (which is much heavier on memory) then the Python client shouldn't break either - it would Just Work(TM). The other side of that is an obvious reason to use this library would be to do something like scripting a bulk clean up entities as well. EDIT: Basically I'd argue usability wise, the constraint goes the other way - if reconnects are automatic, then updating receive size limits should be automatic too unless the the user specifically has a reason they want to constrain it.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, although I'm a bit worried about the readability/maintainability of this client with this added overhead. Another, more simple approach could be to just have a config param to set the default message size for more advanced setups. |
||
| raise ConnectionFailed() | ||
|
|
||
| if msg.type != WSMsgType.TEXT: | ||
|
|
@@ -321,20 +369,26 @@ async def _process_messages(self) -> None: | |
| LOGGER.debug("Received message:\n%s\n", pprint.pformat(msg)) | ||
|
|
||
| self._handle_incoming_message(data) | ||
| except Exception as e: | ||
| terminating_exception = e | ||
|
|
||
| finally: | ||
| LOGGER.debug("Listen completed. Cleaning up") | ||
| if terminating_exception is not None: | ||
| LOGGER.debug("Listen finished with exception - cancelling futures with exception name: %s", | ||
| type(terminating_exception).__name__) | ||
| else: | ||
| LOGGER.debug("Listen completing normally") | ||
|
|
||
| for future in self._result_futures.values(): | ||
| future.cancel() | ||
| LOGGER.debug("Listen completed. Cleaning up") | ||
| await self._close_client() | ||
|
|
||
| if not self._client.closed: | ||
| await self._client.close() | ||
| LOGGER.debug("Connection closed - cancelling futures") | ||
| for future in self._result_futures.values(): | ||
| future.cancel(msg=terminating_exception) | ||
|
|
||
| if self._shutdown_complete_event: | ||
| self._shutdown_complete_event.set() | ||
| else: | ||
| self._on_connection_lost() | ||
| if self._shutdown_complete_event: | ||
| self._shutdown_complete_event.set() | ||
| else: | ||
| self._on_connection_lost() | ||
|
|
||
| def _handle_incoming_message(self, msg: Message) -> None: | ||
| """Handle incoming message.""" | ||
|
|
@@ -407,12 +461,8 @@ async def auto_reconnect(): | |
| attempts = 0 | ||
| sleep_time = 2 | ||
| while True: | ||
| # Try to reconnect right away in case this is recoverable immediately... | ||
| attempts += 1 | ||
| if attempts > 20: | ||
| sleep_time = 60 | ||
| elif sleep_time > 10: | ||
| sleep_time = 10 | ||
| await asyncio.sleep(sleep_time) | ||
| try: | ||
| await self.connect() | ||
| # resubscribe all subscriptions | ||
|
|
@@ -424,6 +474,12 @@ async def auto_reconnect(): | |
| return | ||
| except CannotConnect: | ||
| pass | ||
| # Failed, go to sleep now... | ||
| if attempts > 20: | ||
| sleep_time = 60 | ||
| elif sleep_time > 10: | ||
| sleep_time = 10 | ||
| await asyncio.sleep(sleep_time) | ||
| if attempts >= 30: | ||
| LOGGER.warning( | ||
| "Still could not reconnect after %s attempts, is the server alive ?", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if I'm a fan of this pattern. Why not introduce an argument to specify if the command may be retried ?
Also, I think it would be good to specify a number of retries. So for example introduce an argument "retries:int =3" and thus by default retry a command 3 times on common connection errors