Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions homeassistant_api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,29 @@ class HomeassistantAPIError(Exception):
class RequestError(HomeassistantAPIError):
"""Error raised when an issue occurs when requesting to Homeassistant."""

def __init__(
self, data: Optional[str], /, url: str, message: Optional[str] = None
) -> None:
if message is not None:
super().__init__(
message
+ f" {url!r}"
+ (f" with data: {data!r}" if data is not None else "")
)
elif data is None:
super().__init__(f"An error occurred while making the request to {url!r}")
else:
super().__init__(
f"An error occurred while making the request to {url!r} with data: {data!r}"
)


class RequestTimeoutError(RequestError):
"""Error raised when a request times out."""

def __init__(self, message: str, url: str) -> None:
super().__init__(None, url, message)


class ResponseError(HomeassistantAPIError):
"""Error raised when an issue occurs in a response from Homeassistant."""
Expand Down
2 changes: 1 addition & 1 deletion homeassistant_api/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def process(self) -> Any:
if status_code in (200, 201):
return self.process_content(async_=async_)
if status_code == 400:
raise RequestError(content)
raise RequestError(content, url=self._response.url) # type: ignore
if status_code == 401:
raise UnauthorizedError()
if status_code == 404:
Expand Down
11 changes: 8 additions & 3 deletions homeassistant_api/rawasyncclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ async def __aexit__(self, _, __, ___):
async def async_request(
self,
path: str,
*,
params: str = "", # should be a string of query parameters from construct_params()
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
**kwargs,
Expand All @@ -103,14 +105,15 @@ async def async_request(
return await self.async_response_logic(
await self.async_cache_session.request(
method,
self.endpoint(path),
self.endpoint(path) + f"?{params}" * bool(params),
headers=self.prepare_headers(headers),
**kwargs,
)
)
except asyncio.exceptions.TimeoutError as err:
raise RequestTimeoutError(
f'Home Assistant did not respond in time (timeout: {kwargs.get("timeout", 300)} sec)'
f'Home Assistant did not respond in time (timeout: {kwargs.get("timeout", 300)} sec)',
self.endpoint(path) + f"?{params}" * bool(params),
) from err

@staticmethod
Expand Down Expand Up @@ -143,7 +146,9 @@ async def async_get_logbook_entries(
:code:`GET /api/logbook/<timestamp>`
"""
params, url = self.prepare_get_logbook_entry_params(*args, **kwargs)
data = await self.async_request(url, params=params)
data = await self.async_request(
url, params=self.construct_params(cast(Dict[str, Optional[str]], params))
)
for entry in data:
yield LogbookEntry.model_validate(entry)

Expand Down
41 changes: 31 additions & 10 deletions homeassistant_api/rawbaseclient.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Module for parent RawWrapper class"""

from datetime import datetime
from datetime import datetime, timedelta
from posixpath import join
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from urllib.parse import quote_plus

from .models import Entity

Expand Down Expand Up @@ -62,8 +63,16 @@ def prepare_headers(

@staticmethod
def construct_params(params: Dict[str, Optional[str]]) -> str:
"""Custom method for constructing non-standard query strings"""
return "&".join([k if v is None else f"{k}={v}" for k, v in params.items()])
"""
Custom method for constructing non-standard query strings.

For keys with corresponding None values, the query string will be key only (i.e. :code:`?key1&key2`).
For keys with corresponding non-None values, the query string will be key-value pairs (i.e. :code:`?key1=value1&key2=value2`).
To have an empty value use an empty string :code:`""` (i.e. :code:`?key1=&key2=value2`).
"""
return "&".join(
[k if v is None else f"{k}={quote_plus(v)}" for k, v in params.items()]
)

@staticmethod
def prepare_get_entity_histories_params(
Expand All @@ -73,20 +82,32 @@ def prepare_get_entity_histories_params(
end_timestamp: Optional[datetime] = None,
significant_changes_only: bool = False,
) -> Tuple[Dict[str, Optional[str]], str]:
"""Pre-logic for `Client.get_entity_histories` and `Client.async_get_entity_histories`."""
"""
Pre-logic for :py:meth:`Client.get_entity_histories` and :py:meth:`Client.async_get_entity_histories`.

Ensure timestamps

* use second resolution (microseconds are truncated)
* are timezone-aware
* are URL-encoded (as :py:meth:`construct_params` is used instead of request's default parameter encoding)
"""
params: Dict[str, Optional[str]] = {}
if entities is not None:
params["filter_entity_id"] = ",".join([ent.entity_id for ent in entities])
if end_timestamp is not None:
params["end_time"] = (
end_timestamp.isoformat()
) # Params are automatically URL encoded
if significant_changes_only:
params["significant_changes_only"] = None
if start_timestamp is not None:
start_timestamp = start_timestamp.replace(microsecond=0)
if start_timestamp.tzinfo is None:
start_timestamp = start_timestamp.astimezone()
url = join("history/period/", start_timestamp.isoformat())
else:
url = "history/period"
if end_timestamp is not None:
end_timestamp = end_timestamp.replace(microsecond=0) + timedelta(seconds=1)
if end_timestamp.tzinfo is None:
end_timestamp = end_timestamp.astimezone()
params["end_time"] = end_timestamp.isoformat()
if significant_changes_only:
params["significant_changes_only"] = None
return params, url

@staticmethod
Expand Down
11 changes: 8 additions & 3 deletions homeassistant_api/rawclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __exit__(self, _, __, ___) -> None:
def request(
self,
path: str,
*,
params: str = "", # should be a string of query parameters from construct_params()
method="GET",
headers: Optional[Dict[str, str]] = None,
decode_bytes: bool = True,
Expand All @@ -99,13 +101,14 @@ def request(
if self.cache_session:
resp = self.cache_session.request(
method,
self.endpoint(path),
self.endpoint(path) + f"?{params}" * bool(params),
headers=self.prepare_headers(headers),
**kwargs,
)
except requests.exceptions.Timeout as err:
raise RequestTimeoutError(
f'Home Assistant did not respond in time (timeout: {kwargs.get("timeout", 300)} sec)'
f'Home Assistant did not respond in time (timeout: {kwargs.get("timeout", 300)} sec)',
url=self.endpoint(path) + f"?{params}" * bool(params),
) from err
return self.response_logic(response=resp, decode_bytes=decode_bytes)

Expand Down Expand Up @@ -139,7 +142,9 @@ def get_logbook_entries(
:code:`GET /api/logbook/<timestamp>`
"""
params, url = self.prepare_get_logbook_entry_params(*args, **kwargs)
data = self.request(url, params=params)
data = self.request(
url, params=self.construct_params(cast(Dict[str, Optional[str]], params))
)
for entry in data:
yield LogbookEntry.model_validate(entry)

Expand Down
32 changes: 16 additions & 16 deletions homeassistant_api/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

class WebsocketClient(RawWebsocketClient):
"""

The main class for interactign with the Home Assistant WebSocket API client.

Here's a quick example of how to use the :py:class:`WebsocketClient` class:

.. code-block:: python

from homeassistant_api import WebsocketClient
Expand Down Expand Up @@ -66,7 +66,7 @@ def get_rendered_template(self, template: str) -> str:
def get_config(self) -> dict[str, Any]:
"""
Get the Home Assistant configuration.

Sends command :code:`{"type": "get_config", ...}`.
"""
return cast(
Expand All @@ -80,7 +80,7 @@ def get_config(self) -> dict[str, Any]:
def get_states(self) -> Tuple[State, ...]:
"""
Get a list of states.

Sends command :code:`{"type": "get_states", ...}`.
"""
return tuple(
Expand Down Expand Up @@ -170,7 +170,7 @@ def get_domains(self) -> dict[str, Domain]:
Get a list of services that Home Assistant offers (organized into a dictionary of service domains).

For example, the service :code:`light.turn_on` would be in the domain :code:`light`.

Sends command :code:`{"type": "get_services", ...}`.
"""
resp = self.recv(self.send("get_services"))
Expand Down Expand Up @@ -203,7 +203,7 @@ def trigger_service(
) -> None:
"""
Trigger a service (that doesn't return a response).

Sends command :code:`{"type": "call_service", ...}`.
"""
params = {
Expand Down Expand Up @@ -236,7 +236,7 @@ def trigger_service_with_response(
) -> dict[str, Any]:
"""
Trigger a service (that returns a response) and return the response.

Sends command :code:`{"type": "call_service", ...}`.
"""
params = {
Expand All @@ -261,7 +261,7 @@ def listen_events(
Listen for all events of a certain type.

For example, to listen for all events of type `test_event`:

.. code-block:: python

with ws_client.listen_events("test_event") as events:
Expand All @@ -275,7 +275,7 @@ def listen_events(
def _subscribe_events(self, event_type: Optional[str]) -> int:
"""
Subscribe to all events of a certain type.


Sends command :code:`{"type": "subscribe_events", ...}`.
"""
Expand All @@ -292,15 +292,15 @@ def listen_trigger(

For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML:

.. code-block:: yaml
.. code-block:: yaml

triggers:
# ...
- trigger: state
entity_id: light.kitchen

To subscribe to that same state trigger with :py:class:`WebsocketClient` instead

.. code-block:: python

with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger:
Expand All @@ -309,7 +309,7 @@ def listen_trigger(
if <some_condition>:
break
# exiting the context manager unsubscribes from the trigger

Woohoo! We can now listen to triggers in Python code!
"""
subscription = self._subscribe_trigger(trigger, **trigger_fields)
Expand All @@ -325,7 +325,7 @@ def listen_trigger(
def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int:
"""
Return the subscription id of the trigger we subscribe to.

Sends command :code:`{"type": "subscribe_trigger", ...}`.
"""
return self.recv(
Expand All @@ -351,7 +351,7 @@ def _wait_for(
def _unsubscribe(self, subcription_id: int) -> None:
"""
Unsubscribe from all events of a certain type.

Sends command :code:`{"type": "unsubscribe_events", ...}`.
"""
resp = self.recv(self.send("unsubscribe_events", subscription=subcription_id))
Expand All @@ -361,7 +361,7 @@ def _unsubscribe(self, subcription_id: int) -> None:
def fire_event(self, event_type: str, **event_data) -> Context:
"""
Fire an event.

Sends command :code:`{"type": "fire_event", ...}`.
"""
params: dict[str, Any] = {"event_type": event_type}
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from homeassistant_api import Client, WebsocketClient

logging.basicConfig(level=logging.INFO)

TIMEOUT = 300


Expand Down
20 changes: 18 additions & 2 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def test_get_logbook_entries(cached_client: Client) -> None:

async def test_async_get_logbook_entries(async_cached_client: Client) -> None:
"""Tests the `GET /api/logbook/<timestamp>` endpoint."""
async for entry in async_cached_client.async_get_logbook_entries():
async for entry in async_cached_client.async_get_logbook_entries(
filter_entities="sun.sun",
start_timestamp=datetime(2020, 1, 1),
end_timestamp=datetime.now(),
):
assert entry


Expand All @@ -64,12 +68,18 @@ def test_get_entity_histories(cached_client: Client) -> None:
assert sun is not None
for history in cached_client.get_entity_histories(
(sun,),
end_timestamp=datetime(2023, 1, 1),
end_timestamp=datetime.now(), # test for microsecond truncation
start_timestamp=datetime(2020, 1, 1),
significant_changes_only=True,
):
for state in history.states:
assert isinstance(state, State)
break
else:
raise AssertionError("No states in entity history found.")
break
else:
raise AssertionError("No history found.")


async def test_async_get_entity_histories(async_cached_client: Client) -> None:
Expand All @@ -79,6 +89,12 @@ async def test_async_get_entity_histories(async_cached_client: Client) -> None:
async for history in async_cached_client.async_get_entity_histories((sun,)):
for state in history.states:
assert isinstance(state, State)
break
else:
raise AssertionError("No states in entity history found.")
break
else:
raise AssertionError("No history found.")


def test_get_rendered_template(cached_client: Client) -> None:
Expand Down