diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 0c80d9792681f9..5f598c6ce40d7b 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -581,11 +581,12 @@ def async_update_statistics_metadata( *, new_statistic_id: str | UndefinedType = UNDEFINED, new_unit_of_measurement: str | None | UndefinedType = UNDEFINED, + on_done: Callable[[], None] | None = None, ) -> None: """Update statistics metadata for a statistic_id.""" self.queue_task( UpdateStatisticsMetadataTask( - statistic_id, new_statistic_id, new_unit_of_measurement + on_done, statistic_id, new_statistic_id, new_unit_of_measurement ) ) diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index 2529e8012bff55..ce51737777232c 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -71,6 +71,7 @@ def run(self, instance: Recorder) -> None: class UpdateStatisticsMetadataTask(RecorderTask): """Object to store statistics_id and unit for update of statistics metadata.""" + on_done: Callable[[], None] | None statistic_id: str new_statistic_id: str | None | UndefinedType new_unit_of_measurement: str | None | UndefinedType @@ -83,6 +84,8 @@ def run(self, instance: Recorder) -> None: self.new_statistic_id, self.new_unit_of_measurement, ) + if self.on_done: + self.on_done() @dataclass(slots=True) diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 6ac2207b1e04f6..9e4de946c0b0e7 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from datetime import datetime as dt from typing import Any, Literal, cast @@ -48,6 +49,8 @@ ) from .util import PERIOD_SCHEMA, get_instance, resolve_period +UPDATE_STATISTICS_METADATA_TIME_OUT = 10 + UNIT_SCHEMA = vol.Schema( { vol.Optional("conductivity"): vol.In(ConductivityConverter.VALID_UNITS), @@ -357,17 +360,33 @@ async def ws_get_statistics_metadata( vol.Required("unit_of_measurement"): vol.Any(str, None), } ) -@callback -def ws_update_statistics_metadata( +@websocket_api.async_response +async def ws_update_statistics_metadata( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] ) -> None: """Update statistics metadata for a statistic_id. Only the normalized unit of measurement can be updated. """ + done_event = asyncio.Event() + + def update_statistics_metadata_done() -> None: + hass.loop.call_soon_threadsafe(done_event.set) + get_instance(hass).async_update_statistics_metadata( - msg["statistic_id"], new_unit_of_measurement=msg["unit_of_measurement"] + msg["statistic_id"], + new_unit_of_measurement=msg["unit_of_measurement"], + on_done=update_statistics_metadata_done, ) + try: + async with asyncio.timeout(UPDATE_STATISTICS_METADATA_TIME_OUT): + await done_event.wait() + except TimeoutError: + connection.send_error( + msg["id"], websocket_api.ERR_TIMEOUT, "update_statistics_metadata timed out" + ) + return + connection.send_result(msg["id"]) diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index badf25406548b4..70ad3358430b9c 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -2216,6 +2216,31 @@ async def test_update_statistics_metadata( } +async def test_update_statistics_metadata_time_out( + recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test update statistics metadata with time-out error.""" + client = await hass_ws_client() + + with ( + patch.object(recorder.tasks.UpdateStatisticsMetadataTask, "run"), + patch.object(recorder.websocket_api, "UPDATE_STATISTICS_METADATA_TIME_OUT", 0), + ): + await client.send_json_auto_id( + { + "type": "recorder/update_statistics_metadata", + "statistic_id": "sensor.test", + "unit_of_measurement": "dogs", + } + ) + response = await client.receive_json() + assert not response["success"] + assert response["error"] == { + "code": "timeout", + "message": "update_statistics_metadata timed out", + } + + async def test_change_statistics_unit( recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator ) -> None: