Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ repos:
types: [text]
entry: scripts/run-in-env.sh trailing-whitespace-fixer
stages: [commit, push, manual]
- id: mypy
name: mypy
entry: scripts/run-in-env.sh mypy
language: script
types: [python]
require_serial: true
files: ^(music_assistant|pylint)/.+\.py$
# - id: mypy
# name: mypy
# entry: scripts/run-in-env.sh mypy
# language: script
# types: [python]
# require_serial: true
# files: ^(hass_client|pylint)/.+\.py$
3 changes: 2 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ async def connect(args: argparse.Namespace, session: ClientSession) -> None:
websocket_url = args.url.replace("http", "ws") + "/api/websocket"
async with HomeAssistantClient(websocket_url, args.token, session) as client:
await client.subscribe_events(log_events)
await asyncio.sleep(360)
# start listening will wait forever until the connection is closed/lost
await client.start_listening()


def log_events(event: Event) -> None:
Expand Down
101 changes: 41 additions & 60 deletions hass_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
provided by Home Assistant that allows for rapid development of apps
connected to Home Assistant.
"""

from __future__ import annotations

import asyncio
import logging
import os
import pprint
from collections.abc import Callable
from types import TracebackType
from typing import Any
from typing import TYPE_CHECKING, Any

import aiohttp
from aiohttp import (
ClientSession,
ClientWebSocketResponse,
Expand All @@ -26,6 +29,7 @@
AuthenticationFailed,
CannotConnect,
ConnectionFailed,
ConnectionFailedDueToLargeMessage,
FailedCommand,
InvalidMessage,
NotConnected,
Expand All @@ -46,6 +50,9 @@
State,
)

if TYPE_CHECKING:
from types import TracebackType

try:
import orjson as json

Expand All @@ -56,6 +63,7 @@
HAS_ORJSON = False

LOGGER = logging.getLogger(__package__)
MAX_MESSAGE_SIZE = 16 * 1024 * 1024 # 16MB

EventCallback = Callable[[Event], None]
EntityChangedCallback = Callable[[EntityStateEvent], None]
Expand All @@ -81,7 +89,7 @@ def __init__(
"""
self._websocket_url = websocket_url
self._token = token
self._subscriptions: dict[int, tuple[dict, SubscriptionCallback]] = {}
self._subscriptions: dict[int, tuple[dict[str, Any], SubscriptionCallback]] = {}
self._version = None
self._last_msg_id = 1
self._loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -244,21 +252,20 @@ async def subscribe(
self._subscriptions[message_id] = sub

def remove_listener():
# we need to lookup the key because the subscription id can change due to reconnects
key = next((x for x, y in self._subscriptions.items() if y == sub), None)
if not key:
return
self._subscriptions.pop(key)
self._subscriptions.pop(message_id)
# try to unsubscribe
if "subscribe" not in message_base["type"]:
return
unsub_command = message_base["type"].replace("subscribe", "unsubscribe")
asyncio.create_task(self.send_command_no_wait(unsub_command, subscription=key))
asyncio.create_task(self.send_command_no_wait(unsub_command, subscription=message_id))

return remove_listener

async def connect(self) -> None:
"""Connect to the websocket server."""
if self.connected:
# already connected
return
if not self._http_session_provided and self._http_session is None:
self._http_session = ClientSession(
loop=self._loop, connector=TCPConnector(enable_cleanup_closed=True)
Expand All @@ -267,11 +274,16 @@ async def connect(self) -> None:
ws_token = self._token or os.environ.get("HASSIO_TOKEN")
LOGGER.debug("Connecting to Home Assistant Websocket API on %s", ws_url)
try:
self._client = await self._http_session.ws_connect(ws_url, heartbeat=55)
self._client = await self._http_session.ws_connect(
ws_url, heartbeat=55, max_msg_size=MAX_MESSAGE_SIZE
)
version_msg: AuthRequiredMessage = await self._client.receive_json()
self._version = version_msg["ha_version"]
# send authentication
auth_command: AuthCommandMessage = {"type": "auth", "access_token": ws_token}
auth_command: AuthCommandMessage = {
"type": "auth",
"access_token": ws_token,
}
await self._client.send_json(auth_command)
auth_result: AuthResultMessage = await self._client.receive_json()
if auth_result["type"] != "auth_ok":
Expand All @@ -287,16 +299,12 @@ async def connect(self) -> None:
self._websocket_url.split("://")[1].split("/")[0],
self.version,
)
# start task to handle incoming messages
self._loop.create_task(self._process_messages())

async def disconnect(self) -> None:
"""Disconnect the client."""
LOGGER.debug("Closing client connection")

if not self.connected:
return

LOGGER.debug("Closing client connection")
self._shutdown_complete_event = asyncio.Event()
await self._client.close()

Expand All @@ -305,8 +313,9 @@ async def disconnect(self) -> None:
self._http_session = None
await self._shutdown_complete_event.wait()

async def _process_messages(self) -> None:
"""Start listening to the websocket."""
async def start_listening(self) -> None:
"""Connect (if needed) and start listening to incoming messages from the server."""
await self.connect()
try:
while not self._client.closed:
msg = await self._client.receive()
Expand All @@ -315,6 +324,11 @@ async def _process_messages(self) -> None:
break

if msg.type == WSMsgType.ERROR:
if msg.data.code == aiohttp.WSCloseCode.MESSAGE_TOO_BIG:
# in the edge case we run into this, the lib consumer could
# decide to increase the MAX_MESSAGE_SIZE constant but messages
# bigger than 16MB are really just to big for a websocket.
raise ConnectionFailedDueToLargeMessage
raise ConnectionFailed

if msg.type != WSMsgType.TEXT:
Expand All @@ -334,7 +348,7 @@ async def _process_messages(self) -> None:

finally:
LOGGER.debug("Listen completed. Cleaning up")

# cancel all command-tasks awaiting a result
for future in self._result_futures.values():
future.cancel()

Expand All @@ -343,8 +357,6 @@ async def _process_messages(self) -> None:

if self._shutdown_complete_event:
self._shutdown_complete_event.set()
else:
self._on_connection_lost()

def _handle_incoming_message(self, msg: Message) -> None:
"""Handle incoming message."""
Expand All @@ -360,7 +372,7 @@ def _handle_incoming_message(self, msg: Message) -> None:
future.set_result(msg["result"])
return

future.set_exception(FailedCommand(msg["id"], msg["error"]["message"]))
future.set_exception(FailedCommand(msg["error"]["message"]))
return

# subscription callback
Expand Down Expand Up @@ -393,56 +405,25 @@ async def _send_json_message(self, message: dict[str, Any]) -> None:
else:
await self._client.send_json(message)

async def __aenter__(self) -> "HomeAssistantClient":
async def __aenter__(self) -> HomeAssistantClient:
"""Connect to the websocket."""
await self.connect()
return self

async def __aexit__(
self, exc_type: Exception, exc_value: str, traceback: TracebackType
) -> None:
"""Disconnect from the websocket."""
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
"""Exit context manager."""
await self.disconnect()

def __repr__(self) -> str:
"""Return the representation."""
prefix = "" if self.connected else "not "
return f"{type(self).__name__}(ws_server_url={self._websocket_url!r}, {prefix}connected)"

def _on_connection_lost(self):
"""Call when the connection gets (unexpectedly) lost."""

async def auto_reconnect():
"""Reconnect the websocket connection when connection lost."""
attempts = 0
sleep_time = 2
while True:
attempts += 1
if attempts > 20:
sleep_time = 60
elif sleep_time > 10:
sleep_time = 10
await asyncio.sleep(sleep_time)
try:
await self.connect()
# resubscribe all subscriptions
subscriptions = list(self._subscriptions.values())
self._subscriptions = {}
for sub in subscriptions:
message_id = await self._get_message_id()
await self.send_command(**sub[0], message_id=message_id)
return
except CannotConnect:
pass
if attempts >= 30:
LOGGER.warning(
"Still could not reconnect after %s attempts, is the server alive ?",
attempts,
)

LOGGER.debug("Connection lost, will auto reconnect...")
self._loop.create_task(auto_reconnect())

async def _get_message_id(self) -> int:
"""Return a new message id."""
async with self._msg_id_lock:
Expand Down
10 changes: 4 additions & 6 deletions hass_client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def __init__(self, error: Exception | None = None) -> None:
super().__init__(f"{error}", error)


class ConnectionFailedDueToLargeMessage(ConnectionFailed):
"""Exception raised when an established connection fails due to an oversize message."""


class NotFoundError(BaseHassClientError):
"""Exception that is raised when an entity can't be found."""

Expand All @@ -55,9 +59,3 @@ class AuthenticationFailed(BaseHassClientError):

class FailedCommand(BaseHassClientError):
"""When a command has failed."""

def __init__(self, message_id: str, error_code: str):
"""Initialize a failed command error."""
super().__init__(f"Command failed: {error_code}")
self.message_id = message_id
self.error_code = error_code
Loading