
Initial orjson support take 3 #73849

Merged
16 commits merged on Jun 22, 2022
Initial orjson support take 2
Still need to work out problem building wheels

--

Redux of #72754 / #32153. Now possible since the following has been solved:
ijl/orjson#220 (comment)

This implements orjson for the paths that use our default encoder. It does not use orjson where `ExtendedJSONEncoder` is needed, as those areas tend to be called far less frequently. If desired, that could be done in a follow-up, but it seemed like a case of diminishing returns (except maybe for large diagnostics files or traces, which are not expected to be downloaded frequently).

Areas where this makes a perceptible difference:
- Anything that subscribes to entities (Initial subscribe_entities payload)
- Initial download of registries on first connection / restore
- History queries
- Saving states to the database
- Large logbook queries
- Anything that subscribes to events (appdaemon)

Caveats:
orjson serializes dataclasses natively (and much faster), which eliminates the need to implement `as_dict` in many places where the data is already in a dataclass. This works well as long as everything in the dataclass can also be serialized. I audited all the places where we have an `as_dict` for a dataclass and found that only backups needed adjusting (support for `Path` had to be added). I was a little worried about `SensorExtraStoredData` with `Decimal`, but it works out because the value is converted before it reaches the JSON encoding. cc @dgomes

If this turns out to be a problem, we can disable it with `option |=` [orjson.OPT_PASSTHROUGH_DATACLASS](https://github.com/ijl/orjson#opt_passthrough_dataclass) and it will fall back to `as_dict`.

It's quite impressive for history queries:
<img width="1271" alt="Screen_Shot_2022-05-30_at_23_46_30" src="https://user-images.githubusercontent.com/663432/171145699-661ad9db-d91d-4b2d-9c1a-9d7866c03a73.png">
bdraco committed Jun 10, 2022
commit df1f37f495e654891dbbe3d39e8392bb78b418a7
2 changes: 1 addition & 1 deletion homeassistant/components/history/__init__.py
@@ -24,10 +24,10 @@
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util

2 changes: 1 addition & 1 deletion homeassistant/components/logbook/websocket_api.py
@@ -14,10 +14,10 @@
from homeassistant.components.recorder import get_instance
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.entityfilter import EntityFilter
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.json import JSON_DUMP
import homeassistant.util.dt as dt_util

from .const import LOGBOOK_ENTITIES_FILTER
10 changes: 3 additions & 7 deletions homeassistant/components/recorder/const.py
@@ -1,12 +1,10 @@
"""Recorder constants."""

from functools import partial
import json
from typing import Final

from homeassistant.backports.enum import StrEnum
from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import
JSON_DUMP,
)

DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"
@@ -27,8 +25,6 @@

DB_WORKER_PREFIX = "DbWorker"

JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":"))

ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES}

ATTR_KEEP_DAYS = "keep_days"
10 changes: 6 additions & 4 deletions homeassistant/components/recorder/core.py
@@ -754,19 +754,20 @@ def _process_non_state_changed_event_into_session(self, event: Event) -> None:
return

try:
shared_data = EventData.shared_data_from_event(event)
shared_data_bytes = EventData.shared_data_bytes_from_event(event)
except (TypeError, ValueError) as ex:
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
return

shared_data = shared_data_bytes.decode("utf-8")
# Matching attributes found in the pending commit
if pending_event_data := self._pending_event_data.get(shared_data):
dbevent.event_data_rel = pending_event_data
# Matching attributes id found in the cache
elif data_id := self._event_data_ids.get(shared_data):
dbevent.data_id = data_id
else:
data_hash = EventData.hash_shared_data(shared_data)
data_hash = EventData.hash_shared_data_bytes(shared_data_bytes)
# Matching attributes found in the database
if data_id := self._find_shared_data_in_db(data_hash, shared_data):
self._event_data_ids[shared_data] = dbevent.data_id = data_id
@@ -785,7 +786,7 @@ def _process_state_changed_event_into_session(self, event: Event) -> None:
assert self.event_session is not None
try:
dbstate = States.from_event(event)
shared_attrs = StateAttributes.shared_attrs_from_event(
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event, self._exclude_attributes_by_domain
)
except (TypeError, ValueError) as ex:
@@ -796,6 +797,7 @@ def _process_state_changed_event_into_session(self, event: Event) -> None:
)
return

shared_attrs = shared_attrs_bytes.decode("utf-8")
dbstate.attributes = None
# Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
@@ -804,7 +806,7 @@ def _process_state_changed_event_into_session(self, event: Event) -> None:
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
dbstate.attributes_id = attributes_id
else:
attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
# Matching attributes found in the database
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
dbstate.attributes_id = attributes_id
57 changes: 29 additions & 28 deletions homeassistant/components/recorder/db_schema.py
@@ -3,12 +3,12 @@

from collections.abc import Callable
from datetime import datetime, timedelta
import json
import logging
from typing import Any, cast

import ciso8601
from fnvhash import fnv1a_32
import orjson
from sqlalchemy import (
JSON,
BigInteger,
@@ -39,9 +39,10 @@
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSON_DUMP, json_bytes
import homeassistant.util.dt as dt_util

from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from .const import ALL_DOMAIN_EXCLUDE_ATTRS
from .models import StatisticData, StatisticMetaData, process_timestamp

# SQLAlchemy Schema
@@ -124,7 +125,7 @@ def literal_processor(self, dialect: str) -> Callable[[Any], str]:

def process(value: Any) -> str:
"""Dump json."""
return json.dumps(value)
return JSON_DUMP(value)

return process

@@ -187,15 +188,15 @@ def to_native(self, validate_entity_id: bool = True) -> Event | None:
try:
return Event(
self.event_type,
json.loads(self.event_data) if self.event_data else {},
orjson.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None

@@ -223,25 +224,26 @@ def __repr__(self) -> str:
@staticmethod
def from_event(event: Event) -> EventData:
"""Create object from an event."""
shared_data = JSON_DUMP(event.data)
shared_data = json_bytes(event.data)
return EventData(
shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
shared_data=shared_data.decode("utf-8"),
hash=EventData.hash_shared_data_bytes(shared_data),
)

@staticmethod
def shared_data_from_event(event: Event) -> str:
"""Create shared_attrs from an event."""
return JSON_DUMP(event.data)
def shared_data_bytes_from_event(event: Event) -> bytes:
"""Create shared_data from an event."""
return json_bytes(event.data)

@staticmethod
def hash_shared_data(shared_data: str) -> int:
def hash_shared_data_bytes(shared_data_bytes: bytes) -> int:
"""Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data.encode("utf-8")))
return cast(int, fnv1a_32(shared_data_bytes))

def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_data))
return cast(dict[str, Any], orjson.loads(self.shared_data))
except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self)
return {}
@@ -328,9 +330,9 @@ def to_native(self, validate_entity_id: bool = True) -> State | None:
parent_id=self.context_parent_id,
)
try:
attrs = json.loads(self.attributes) if self.attributes else {}
attrs = orjson.loads(self.attributes) if self.attributes else {}
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
if self.last_changed is None or self.last_changed == self.last_updated:
@@ -376,40 +378,39 @@ def from_event(event: Event) -> StateAttributes:
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
attr_bytes = b"{}" if state is None else json_bytes(state.attributes)
dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8"))
dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes)
return dbstate

@staticmethod
def shared_attrs_from_event(
def shared_attrs_bytes_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
) -> bytes:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
return b"{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
return json_bytes(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)

@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int:
"""Return the hash of orjson encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs_bytes))

def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
return cast(dict[str, Any], orjson.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}

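The `hash_shared_data_bytes` / `hash_shared_attrs_bytes` helpers above now take the already-encoded bytes directly, avoiding a redundant `str.encode("utf-8")` on the hot path. For reference, the FNV-1a 32-bit hash they delegate to (`fnvhash.fnv1a_32` in the real code) is simple enough to sketch in pure Python:

```python
def fnv1a_32(data: bytes) -> int:
    # FNV-1a, 32-bit: XOR each byte into the hash, then multiply by the
    # FNV prime, keeping the result within 32 bits.
    h = 0x811C9DC5  # FNV offset basis
    for byte in data:
        h ^= byte
        h = (h * 0x01000193) & 0xFFFFFFFF  # 0x01000193 is the FNV prime
    return h
```

Because the hash operates on the encoded payload, hashing the `bytes` that `json_bytes` already produced is both correct and cheaper than re-encoding a decoded string.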
4 changes: 2 additions & 2 deletions homeassistant/components/recorder/models.py
@@ -2,10 +2,10 @@
from __future__ import annotations

from datetime import datetime
import json
import logging
from typing import Any, TypedDict, overload

import orjson
from sqlalchemy.engine.row import Row

from homeassistant.components.websocket_api.const import (
@@ -253,7 +253,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT:
return {}
try:
attr_cache[source] = attributes = json.loads(source)
attr_cache[source] = attributes = orjson.loads(source)
except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {}
18 changes: 9 additions & 9 deletions homeassistant/components/websocket_api/commands.py
@@ -29,7 +29,7 @@
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
@@ -241,13 +241,13 @@ def handle_get_states(
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
@@ -256,13 +256,13 @@ def handle_get_states(
serialized = []
for state in states:
try:
serialized.append(const.JSON_DUMP(state))
serialized.append(JSON_DUMP(state))
except (ValueError, TypeError):
# Error is already logged above
pass

# We now have partially serialized states. Craft some JSON.
response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)

@@ -315,13 +315,13 @@ def forward_entity_changes(event: Event) -> None:
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
@@ -330,14 +330,14 @@ def forward_entity_changes(event: Event) -> None:
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
try:
const.JSON_DUMP(state_dict)
JSON_DUMP(state_dict)
except (ValueError, TypeError):
cannot_serialize.append(entity_id)

for entity_id in cannot_serialize:
del add_entities[entity_id]

connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data)))
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))


@decorators.websocket_command({vol.Required("type"): "get_services"})
3 changes: 2 additions & 1 deletion homeassistant/components/websocket_api/connection.py
@@ -11,6 +11,7 @@
from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.json import JSON_DUMP

from . import const, messages

@@ -56,7 +57,7 @@ def send_result(self, msg_id: int, result: Any | None = None) -> None:
async def send_big_result(self, msg_id: int, result: Any) -> None:
"""Send a result message that would be expensive to JSON serialize."""
content = await self.hass.async_add_executor_job(
const.JSON_DUMP, messages.result_message(msg_id, result)
JSON_DUMP, messages.result_message(msg_id, result)
)
self.send_message(content)

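`send_big_result` keeps the existing pattern of moving an expensive serialization off the event loop. A minimal stdlib-only sketch of that pattern (using `json` here purely so the example is dependency-free; the real code dispatches `JSON_DUMP` via `hass.async_add_executor_job`):

```python
import asyncio
import json


async def send_big_result(result):
    loop = asyncio.get_running_loop()
    # Dumping a large payload inline would block every other callback on
    # the event loop; run the dump in a worker thread instead.
    return await loop.run_in_executor(None, json.dumps, result)


print(asyncio.run(send_big_result({"id": 1, "result": list(range(3))})))
```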
7 changes: 0 additions & 7 deletions homeassistant/components/websocket_api/const.py
@@ -4,12 +4,9 @@
import asyncio
from collections.abc import Awaitable, Callable
from concurrent import futures
from functools import partial
import json
from typing import TYPE_CHECKING, Any, Final

from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder

if TYPE_CHECKING:
from .connection import ActiveConnection # noqa: F401
@@ -53,10 +50,6 @@
# Data used to store the current connection list
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"

JSON_DUMP: Final = partial(
json.dumps, cls=JSONEncoder, allow_nan=False, separators=(",", ":")
)

COMPRESSED_STATE_STATE = "s"
COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c"