Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

History API entity_id validation #90067

Merged
merged 9 commits into from
Apr 14, 2023
12 changes: 9 additions & 3 deletions homeassistant/components/history/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import get_instance, history
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, valid_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.typing import ConfigType
Expand Down Expand Up @@ -73,6 +73,14 @@ async def get(
"filter_entity_id is missing", HTTPStatus.BAD_REQUEST
)

hass = request.app["hass"]

for entity_id in entity_ids:
if not hass.states.get(entity_id) and not valid_entity_id(entity_id):
return self.json_message(
"Invalid filter_entity_id", HTTPStatus.BAD_REQUEST
)

now = dt_util.utcnow()
if datetime_:
start_time = dt_util.as_utc(datetime_)
Expand All @@ -96,8 +104,6 @@ async def get(
minimal_response = "minimal_response" in request.query
no_attributes = "no_attributes" in request.query

hass = request.app["hass"]

if (
not include_start_time_state
and entity_ids
Expand Down
15 changes: 13 additions & 2 deletions homeassistant/components/history/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
State,
callback,
is_callback,
valid_entity_id,
)
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
Expand Down Expand Up @@ -95,7 +96,7 @@ def _ws_get_significant_states(
vol.Required("type"): "history/history_during_period",
vol.Required("start_time"): str,
vol.Optional("end_time"): str,
vol.Optional("entity_ids"): [str],
vol.Required("entity_ids"): [str],
vol.Optional("include_start_time_state", default=True): bool,
vol.Optional("significant_changes_only", default=True): bool,
vol.Optional("minimal_response", default=False): bool,
Expand Down Expand Up @@ -129,7 +130,12 @@ async def ws_get_history_during_period(
connection.send_result(msg["id"], {})
return

entity_ids = msg.get("entity_ids")
entity_ids: list[str] = msg["entity_ids"]
for entity_id in entity_ids:
if not hass.states.get(entity_id) and not valid_entity_id(entity_id):
connection.send_error(msg["id"], "invalid_entity_ids", "Invalid entity_ids")
return

include_start_time_state = msg["include_start_time_state"]
no_attributes = msg["no_attributes"]

Expand Down Expand Up @@ -428,6 +434,11 @@ async def ws_stream(
return

entity_ids: list[str] = msg["entity_ids"]
for entity_id in entity_ids:
if not hass.states.get(entity_id) and not valid_entity_id(entity_id):
connection.send_error(msg["id"], "invalid_entity_ids", "Invalid entity_ids")
return

include_start_time_state = msg["include_start_time_state"]
significant_changes_only = msg["significant_changes_only"]
no_attributes = msg["no_attributes"]
Expand Down
69 changes: 69 additions & 0 deletions tests/components/history/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,72 @@ async def test_fetch_period_api_with_no_entity_ids(
assert response.status == HTTPStatus.BAD_REQUEST
response_json = await response.json()
assert response_json == {"message": "filter_entity_id is missing"}


@pytest.mark.parametrize(
("filter_entity_id", "status_code", "response_contains1", "response_contains2"),
[
("light.kitchen,light.cow", HTTPStatus.OK, "light.kitchen", "light.cow"),
(
"light.kitchen,light.cow&",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"light.kitchen,li-ght.cow",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"light.kit!chen",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"lig+ht.kitchen,light.cow",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
(
"light.kitchenlight.cow",
HTTPStatus.BAD_REQUEST,
"message",
"Invalid filter_entity_id",
),
("cow", HTTPStatus.BAD_REQUEST, "message", "Invalid filter_entity_id"),
],
)
async def test_history_with_invalid_entity_ids(
filter_entity_id,
status_code,
response_contains1,
response_contains2,
recorder_mock: Recorder,
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test sending valid and invalid entity_ids to the API."""
await async_setup_component(
hass,
"history",
{"history": {}},
)
hass.states.async_set("light.kitchen", "on")
hass.states.async_set("light.cow", "on")

await async_wait_recording_done(hass)
now = dt_util.utcnow().isoformat()
client = await hass_client()

response = await client.get(
f"/api/history/period/{now}",
params={"filter_entity_id": filter_entity_id},
)
assert response.status == status_code
response_json = await response.json()
assert response_contains1 in str(response_json)
assert response_contains2 in str(response_json)
2 changes: 2 additions & 0 deletions tests/components/history/test_init_db_schema_30.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ async def test_history_during_period_bad_start_time(
{
"id": 1,
"type": "history/history_during_period",
"entity_ids": ["sensor.pet"],
"start_time": "cats",
}
)
Expand All @@ -1004,6 +1005,7 @@ async def test_history_during_period_bad_end_time(
{
"id": 1,
"type": "history/history_during_period",
"entity_ids": ["sensor.pet"],
"start_time": now.isoformat(),
"end_time": "dogs",
}
Expand Down
Loading