Skip to content

Commit

Permalink
Add additional CalendarEvent validation (home-assistant#89533)
Browse files Browse the repository at this point in the history
Add additional event validation
  • Loading branch information
allenporter authored Mar 15, 2023
1 parent c33ca4f commit 4ddcb14
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 82 deletions.
6 changes: 6 additions & 0 deletions homeassistant/components/caldav/calendar.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,4 +356,10 @@ def get_end_date(obj):
else:
enddate = obj.dtstart.value + timedelta(days=1)

# End date for an all day event is exclusive. This fixes the case where
# an all day event has a start and end values are the same, or the event
# has a zero duration.
if not isinstance(enddate, datetime) and obj.dtstart.value == enddate:
enddate += timedelta(days=1)

return enddate
194 changes: 115 additions & 79 deletions homeassistant/components/calendar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@
VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"}


def _has_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Assert that all datetime values have a timezone."""

def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Validate that all datetime values have a timezone."""
for k in keys:
if (
(value := obj.get(k))
and isinstance(value, datetime.datetime)
and value.tzinfo is None
):
raise vol.Invalid("Expected all values to have a timezone")
return obj

return validate


def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all datetime values have a consistent timezone."""

Expand All @@ -89,7 +106,7 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]
"""Convert all datetime values to the local timezone."""

def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys that are datetime values have the same timezone."""
"""Convert all keys that are datetime values to local timezone."""
for k in keys:
if (value := obj.get(k)) and isinstance(value, datetime.datetime):
obj[k] = dt.as_local(value)
Expand All @@ -98,23 +115,59 @@ def validate(obj: dict[str, Any]) -> dict[str, Any]:
return validate


def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the specified values are sequential."""
def _has_duration(
start_key: str, end_key: str
) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the time span between start and end is positive."""

def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict are in order."""
values = []
for k in keys:
if not (value := obj.get(k)):
return obj
values.append(value)
if all(values) and values != sorted(values):
raise vol.Invalid(f"Values were not in order: {values}")
if (start := obj.get(start_key)) and (end := obj.get(end_key)):
duration = end - start
if duration.total_seconds() <= 0:
raise vol.Invalid(f"Expected positive event duration ({start}, {end})")
return obj

return validate


def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all values are of the same type."""

def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict have values of the same type."""
uniq_values = groupby(type(obj[k]) for k in keys)
if len(list(uniq_values)) > 1:
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
return obj

return validate


def _validate_rrule(value: Any) -> str:
"""Validate a recurrence rule string."""
if value is None:
raise vol.Invalid("rrule value is None")

if not isinstance(value, str):
raise vol.Invalid("rrule value expected a string")

try:
rrulestr(value)
except ValueError as err:
raise vol.Invalid(f"Invalid rrule: {str(err)}") from err

# Example format: FREQ=DAILY;UNTIL=...
rule_parts = dict(s.split("=", 1) for s in value.split(";"))
if not (freq := rule_parts.get("FREQ")):
raise vol.Invalid("rrule did not contain FREQ")

if freq not in VALID_FREQS:
raise vol.Invalid(f"Invalid frequency for rule: {value}")

return str(value)


CREATE_EVENT_SERVICE = "create_event"
CREATE_EVENT_SCHEMA = vol.All(
cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN),
Expand Down Expand Up @@ -149,8 +202,42 @@ def validate(obj: dict[str, Any]) -> dict[str, Any]:
),
_has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
_as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
_is_sorted(EVENT_START_DATE, EVENT_END_DATE),
_is_sorted(EVENT_START_DATETIME, EVENT_END_DATETIME),
_has_duration(EVENT_START_DATE, EVENT_END_DATE),
_has_duration(EVENT_START_DATETIME, EVENT_END_DATETIME),
)

WEBSOCKET_EVENT_SCHEMA = vol.Schema(
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_has_duration(EVENT_START, EVENT_END),
)
)

# Validation for the CalendarEvent dataclass
CALENDAR_EVENT_SCHEMA = vol.Schema(
vol.All(
{
vol.Required("start"): vol.Any(cv.date, cv.datetime),
vol.Required("end"): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type("start", "end"),
_has_timezone("start", "end"),
_has_consistent_timezone("start", "end"),
_as_local_timezone("start", "end"),
_has_duration("start", "end"),
),
extra=vol.ALLOW_EXTRA,
)


Expand Down Expand Up @@ -243,6 +330,19 @@ def as_dict(self) -> dict[str, Any]:
"all_day": self.all_day,
}

def __post_init__(self) -> None:
"""Perform validation on the CalendarEvent."""

def skip_none(obj: Iterable[tuple[str, Any]]) -> dict[str, str]:
return {k: v for k, v in obj if v is not None}

try:
CALENDAR_EVENT_SCHEMA(dataclasses.asdict(self, dict_factory=skip_none))
except vol.Invalid as err:
raise HomeAssistantError(
f"Failed to validate CalendarEvent: {err}"
) from err


def _event_dict_factory(obj: Iterable[tuple[str, Any]]) -> dict[str, str]:
"""Convert CalendarEvent dataclass items to dictionary of attributes."""
Expand Down Expand Up @@ -316,30 +416,6 @@ def is_offset_reached(
return start + offset_time <= dt.now(start.tzinfo)


def _validate_rrule(value: Any) -> str:
"""Validate a recurrence rule string."""
if value is None:
raise vol.Invalid("rrule value is None")

if not isinstance(value, str):
raise vol.Invalid("rrule value expected a string")

try:
rrulestr(value)
except ValueError as err:
raise vol.Invalid(f"Invalid rrule: {str(err)}") from err

# Example format: FREQ=DAILY;UNTIL=...
rule_parts = dict(s.split("=", 1) for s in value.split(";"))
if not (freq := rule_parts.get("FREQ")):
raise vol.Invalid("rrule did not contain FREQ")

if freq not in VALID_FREQS:
raise vol.Invalid(f"Invalid frequency for rule: {value}")

return str(value)


class CalendarEntity(Entity):
"""Base class for calendar event entities."""

Expand Down Expand Up @@ -447,6 +523,7 @@ async def get(self, request: web.Request, entity_id: str) -> web.Response:
request.app["hass"], start_date, end_date
)
except HomeAssistantError as err:
_LOGGER.debug("Error reading events: %s", err)
return self.json_message(
f"Error reading events: {err}", HTTPStatus.INTERNAL_SERVER_ERROR
)
Expand Down Expand Up @@ -481,38 +558,11 @@ async def get(self, request: web.Request) -> web.Response:
return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"])))


def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all values are of the same type."""

def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict have values of the same type."""
uniq_values = groupby(type(obj[k]) for k in keys)
if len(list(uniq_values)) > 1:
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
return obj

return validate


@websocket_api.websocket_command(
{
vol.Required("type"): "calendar/event/create",
vol.Required("entity_id"): cv.entity_id,
CONF_EVENT: vol.Schema(
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),
CONF_EVENT: WEBSOCKET_EVENT_SCHEMA,
}
)
@websocket_api.async_response
Expand Down Expand Up @@ -595,21 +645,7 @@ async def handle_calendar_event_delete(
vol.Required(EVENT_UID): cv.string,
vol.Optional(EVENT_RECURRENCE_ID): cv.string,
vol.Optional(EVENT_RECURRENCE_RANGE): cv.string,
vol.Required(CONF_EVENT): vol.Schema(
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),
vol.Required(CONF_EVENT): WEBSOCKET_EVENT_SCHEMA,
}
)
@websocket_api.async_response
Expand Down
13 changes: 11 additions & 2 deletions tests/components/calendar/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,23 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
"end_date_time": "2022-04-01T06:00:00",
},
vol.error.MultipleInvalid,
"Values were not in order",
"Expected positive event duration",
),
(
{
"start_date": "2022-04-02",
"end_date": "2022-04-01",
},
vol.error.MultipleInvalid,
"Values were not in order",
"Expected positive event duration",
),
(
{
"start_date": "2022-04-01",
"end_date": "2022-04-01",
},
vol.error.MultipleInvalid,
"Expected positive event duration",
),
],
ids=[
Expand All @@ -351,6 +359,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
"inconsistent_timezone",
"incorrect_date_order",
"incorrect_datetime_order",
"dates_not_exclusive",
],
)
async def test_create_event_service_invalid_params(
Expand Down
2 changes: 1 addition & 1 deletion tests/components/google/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ async def test_add_event_failure(

with pytest.raises(HomeAssistantError):
await add_event_call_service(
{"start_date": "2022-05-01", "end_date": "2022-05-01"}
{"start_date": "2022-05-01", "end_date": "2022-05-02"}
)


Expand Down

0 comments on commit 4ddcb14

Please sign in to comment.