Skip to content

Commit

Permalink
Improve validation of entity service schemas (home-assistant#124102)
Browse files Browse the repository at this point in the history
* Improve validation of entity service schemas

* Update tests/helpers/test_entity_platform.py

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
  • Loading branch information
emontnemery and joostlek authored Aug 27, 2024
1 parent 0dc1eb8 commit 55c42fd
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 49 deletions.
23 changes: 22 additions & 1 deletion homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,9 +1305,28 @@ def platform_only_config_schema(domain: str) -> Callable[[dict], dict]:
_HAS_ENTITY_SERVICE_FIELD = has_at_least_one_key(*ENTITY_SERVICE_FIELDS)


def is_entity_service_schema(validator: VolSchemaType) -> bool:
"""Check if the passed validator is an entity schema validator.
The validator must be either of:
- A validator returned by cv._make_entity_service_schema
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.Schema
- A validator returned by cv._make_entity_service_schema, wrapped in a vol.All
Nesting is allowed.
"""
if hasattr(validator, "_entity_service_schema"):
return True
if isinstance(validator, (vol.All)):
return any(is_entity_service_schema(val) for val in validator.validators)
if isinstance(validator, (vol.Schema)):
return is_entity_service_schema(validator.schema)

return False


def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
"""Create an entity service schema."""
return vol.All(
validator = vol.All(
vol.Schema(
{
# The frontend stores data here. Don't use in core.
Expand All @@ -1319,6 +1338,8 @@ def _make_entity_service_schema(schema: dict, extra: int) -> VolSchemaType:
),
_HAS_ENTITY_SERVICE_FIELD,
)
setattr(validator, "_entity_service_schema", True)
return validator


BASE_ENTITY_SCHEMA = _make_entity_service_schema({}, vol.PREVENT_EXTRA)
Expand Down
13 changes: 2 additions & 11 deletions homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,17 +1267,8 @@ def async_register_entity_service(
# Do a sanity check to check this is a valid entity service schema,
# the check could be extended to require All/Any to have sub schema(s)
# with all entity service fields
elif (
# Don't check All/Any
not isinstance(schema, (vol.All, vol.Any))
# Don't check All/Any wrapped in schema
and not isinstance(schema.schema, (vol.All, vol.Any))
and any(key not in schema.schema for key in cv.ENTITY_SERVICE_FIELDS)
):
raise HomeAssistantError(
"The schema does not include all required keys: "
f"{", ".join(str(key) for key in cv.ENTITY_SERVICE_FIELDS)}"
)
elif not cv.is_entity_service_schema(schema):
raise HomeAssistantError("The schema is not an entity service schema")

service_func: str | HassJob[..., Any]
service_func = func if isinstance(func, str) else HassJob(func)
Expand Down
24 changes: 24 additions & 0 deletions tests/helpers/test_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,3 +1805,27 @@ def _mock_validator_schema(real_func, *args):
"string": [hass.loop_thread_id],
}
validator_calls = {}


async def test_is_entity_service_schema(
hass: HomeAssistant,
) -> None:
"""Test cv.is_entity_service_schema."""
for schema in (
vol.Schema({"some": str}),
vol.All(vol.Schema({"some": str})),
vol.Any(vol.Schema({"some": str})),
vol.Any(cv.make_entity_service_schema({"some": str})),
):
assert cv.is_entity_service_schema(schema) is False

for schema in (
cv.make_entity_service_schema({"some": str}),
vol.Schema(cv.make_entity_service_schema({"some": str})),
vol.Schema(vol.All(cv.make_entity_service_schema({"some": str}))),
vol.Schema(vol.Schema(cv.make_entity_service_schema({"some": str}))),
vol.All(cv.make_entity_service_schema({"some": str})),
vol.All(vol.All(cv.make_entity_service_schema({"some": str}))),
vol.All(vol.Schema(cv.make_entity_service_schema({"some": str}))),
):
assert cv.is_entity_service_schema(schema) is True
38 changes: 19 additions & 19 deletions tests/helpers/test_entity_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
callback,
)
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import discovery
from homeassistant.helpers import config_validation as cv, discovery
from homeassistant.helpers.entity_component import EntityComponent, async_update_entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
Expand Down Expand Up @@ -559,28 +559,28 @@ def appender(**kwargs):
async def test_register_entity_service_non_entity_service_schema(
hass: HomeAssistant,
) -> None:
"""Test attempting to register a service with an incomplete schema."""
"""Test attempting to register a service with a non entity service schema."""
component = EntityComponent(_LOGGER, DOMAIN, hass)

with pytest.raises(
HomeAssistantError,
match=(
"The schema does not include all required keys: entity_id, device_id, area_id, "
"floor_id, label_id"
),
for schema in (
vol.Schema({"some": str}),
vol.All(vol.Schema({"some": str})),
vol.Any(vol.Schema({"some": str})),
):
component.async_register_entity_service(
"hello", vol.Schema({"some": str}), Mock()
with pytest.raises(
HomeAssistantError,
match=("The schema is not an entity service schema"),
):
component.async_register_entity_service("hello", schema, Mock())

for idx, schema in enumerate(
(
cv.make_entity_service_schema({"some": str}),
vol.Schema(cv.make_entity_service_schema({"some": str})),
vol.All(cv.make_entity_service_schema({"some": str})),
)

# The check currently does not recurse into vol.All or vol.Any allowing these
# non-compliant schemas to pass
component.async_register_entity_service(
"hello", vol.All(vol.Schema({"some": str})), Mock()
)
component.async_register_entity_service(
"hello", vol.Any(vol.Schema({"some": str})), Mock()
)
):
component.async_register_entity_service(f"test_service_{idx}", schema, Mock())


async def test_register_entity_service_response_data(hass: HomeAssistant) -> None:
Expand Down
38 changes: 20 additions & 18 deletions tests/helpers/test_entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.helpers import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_platform,
entity_registry as er,
Expand Down Expand Up @@ -1812,31 +1813,32 @@ def handle_service(entity, *_):
async def test_register_entity_service_non_entity_service_schema(
hass: HomeAssistant,
) -> None:
"""Test attempting to register a service with an incomplete schema."""
"""Test attempting to register a service with a non entity service schema."""
entity_platform = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)

with pytest.raises(
HomeAssistantError,
match=(
"The schema does not include all required keys: entity_id, device_id, area_id, "
"floor_id, label_id"
),
for schema in (
vol.Schema({"some": str}),
vol.All(vol.Schema({"some": str})),
vol.Any(vol.Schema({"some": str})),
):
with pytest.raises(
HomeAssistantError,
match="The schema is not an entity service schema",
):
entity_platform.async_register_entity_service("hello", schema, Mock())

for idx, schema in enumerate(
(
cv.make_entity_service_schema({"some": str}),
vol.Schema(cv.make_entity_service_schema({"some": str})),
vol.All(cv.make_entity_service_schema({"some": str})),
)
):
entity_platform.async_register_entity_service(
"hello",
vol.Schema({"some": str}),
Mock(),
f"test_service_{idx}", schema, Mock()
)
# The check currently does not recurse into vol.All or vol.Any allowing these
# non-compliant schemas to pass
entity_platform.async_register_entity_service(
"hello", vol.All(vol.Schema({"some": str})), Mock()
)
entity_platform.async_register_entity_service(
"hello", vol.Any(vol.Schema({"some": str})), Mock()
)


@pytest.mark.parametrize("update_before_add", [True, False])
Expand Down

0 comments on commit 55c42fd

Please sign in to comment.