Skip to content

Commit

Permalink
History API entity_id validation (#90067)
Browse files Browse the repository at this point in the history
Co-authored-by: J. Nick Koston <nick@koston.org>
  • Loading branch information
flip-dots and bdraco authored Apr 14, 2023
1 parent f5911bc commit bf45597
Show file tree
Hide file tree
Showing 5 changed files with 412 additions and 5 deletions.
12 changes: 9 additions & 3 deletions homeassistant/components/history/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import get_instance, history
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, valid_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.typing import ConfigType
Expand Down Expand Up @@ -73,6 +73,14 @@ async def get(
"filter_entity_id is missing", HTTPStatus.BAD_REQUEST
)

hass = request.app["hass"]

for entity_id in entity_ids:
if not hass.states.get(entity_id) and not valid_entity_id(entity_id):
return self.json_message(
"Invalid filter_entity_id", HTTPStatus.BAD_REQUEST
)

now = dt_util.utcnow()
if datetime_:
start_time = dt_util.as_utc(datetime_)
Expand All @@ -96,8 +104,6 @@ async def get(
minimal_response = "minimal_response" in request.query
no_attributes = "no_attributes" in request.query

hass = request.app["hass"]

if (
not include_start_time_state
and entity_ids
Expand Down
15 changes: 13 additions & 2 deletions homeassistant/components/history/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
State,
callback,
is_callback,
valid_entity_id,
)
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
Expand Down Expand Up @@ -95,7 +96,7 @@ def _ws_get_significant_states(
vol.Required("type"): "history/history_during_period",
vol.Required("start_time"): str,
vol.Optional("end_time"): str,
vol.Optional("entity_ids"): [str],
vol.Required("entity_ids"): [str],
vol.Optional("include_start_time_state", default=True): bool,
vol.Optional("significant_changes_only", default=True): bool,
vol.Optional("minimal_response", default=False): bool,
Expand Down Expand Up @@ -129,7 +130,12 @@ async def ws_get_history_during_period(
connection.send_result(msg["id"], {})
return

entity_ids = msg.get("entity_ids")
entity_ids: list[str] = msg["entity_ids"]
for entity_id in entity_ids:
if not hass.states.get(entity_id) and not valid_entity_id(entity_id):
connection.send_error(msg["id"], "invalid_entity_ids", "Invalid entity_ids")
return

include_start_time_state = msg["include_start_time_state"]
no_attributes = msg["no_attributes"]

Expand Down Expand Up @@ -428,6 +434,11 @@ async def ws_stream(
return

entity_ids: list[str] = msg["entity_ids"]
for entity_id in entity_ids:
if not hass.states.get(entity_id) and not valid_entity_id(entity_id):
connection.send_error(msg["id"], "invalid_entity_ids", "Invalid entity_ids")
return

include_start_time_state = msg["include_start_time_state"]
significant_changes_only = msg["significant_changes_only"]
no_attributes = msg["no_attributes"]
Expand Down
69 changes: 69 additions & 0 deletions tests/components/history/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,72 @@ async def test_fetch_period_api_with_no_entity_ids(
assert response.status == HTTPStatus.BAD_REQUEST
response_json = await response.json()
assert response_json == {"message": "filter_entity_id is missing"}


@pytest.mark.parametrize(
("filter_entity_id", "status_code", "response_contains1", "response_contains2"),
[
("light.kitchen,light.cow", HTTPStatus.OK, "light.kitchen", "light.cow"),
(
"light.kitchen,light.cow&",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"light.kitchen,li-ght.cow",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"light.kit!chen",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"lig+ht.kitchen,light.cow",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"light.kitchenlight.cow",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
("cow", HTTPStatus.BAD_REQUEST, "message", "Invalid filter_entity_id"),
],
)
async def test_history_with_invalid_entity_ids(
filter_entity_id,
status_code,
response_contains1,
response_contains2,
recorder_mock: Recorder,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test sending valid and invalid entity_ids to the API."""
await async_setup_component(
hass,
"history",
{"history": {}},
)
hass.states.async_set("light.kitchen", "on")
hass.states.async_set("light.cow", "on")

await async_wait_recording_done(hass)
now = dt_util.utcnow().isoformat()
client = await hass_client()

response = await client.get(
f"/api/history/period/{now}",
params={"filter_entity_id": filter_entity_id},
)
assert response.status == status_code
response_json = await response.json()
assert response_contains1 in str(response_json)
assert response_contains2 in str(response_json)
2 changes: 2 additions & 0 deletions tests/components/history/test_init_db_schema_30.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ async def test_history_during_period_bad_start_time(
{
"id": 1,
"type": "history/history_during_period",
"entity_ids": ["sensor.pet"],
"start_time": "cats",
}
)
Expand All @@ -1004,6 +1005,7 @@ async def test_history_during_period_bad_end_time(
{
"id": 1,
"type": "history/history_during_period",
"entity_ids": ["sensor.pet"],
"start_time": now.isoformat(),
"end_time": "dogs",
}
Expand Down
Loading

0 comments on commit bf45597

Please sign in to comment.