Skip to content

Commit

Permalink
Add check for typed ConfigEntry in quality scale validation (home-ass…
Browse files Browse the repository at this point in the history
  • Loading branch information
epenet authored Dec 10, 2024
1 parent 46d4081 commit 95107cf
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 22 deletions.
11 changes: 7 additions & 4 deletions script/hassfest/quality_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,16 +1348,19 @@ def validate_iqs_file(config: Config, integration: Integration) -> None:
"quality_scale", f"Invalid {name}: {humanize_error(data, err)}"
)

rules_done = set[str]()
rules_met = set[str]()
for rule_name, rule_value in data.get("rules", {}).items():
status = rule_value["status"] if isinstance(rule_value, dict) else rule_value
if status not in {"done", "exempt"}:
continue
rules_met.add(rule_name)
if (
status == "done"
and (validator := VALIDATORS.get(rule_name))
and (errors := validator.validate(integration))
if status == "done":
rules_done.add(rule_name)

for rule_name in rules_done:
if (validator := VALIDATORS.get(rule_name)) and (
errors := validator.validate(integration, rules_done=rules_done)
):
for error in errors:
integration.add_error("quality_scale", f"[{rule_name}] {error}")
Expand Down
4 changes: 3 additions & 1 deletion script/hassfest/quality_scale_validation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
class RuleValidationProtocol(Protocol):
"""Protocol for rule validation."""

def validate(self, integration: Integration) -> list[str] | None:
def validate(
self, integration: Integration, *, rules_done: set[str]
) -> list[str] | None:
"""Validate a quality scale rule.
Returns error (if any).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _has_unload_entry_function(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has a config flow."""

init_file = integration.path / "__init__.py"
Expand Down
2 changes: 1 addition & 1 deletion script/hassfest/quality_scale_validation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from script.hassfest.model import Integration


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration implements config flow."""

if not integration.config_flow:
Expand Down
2 changes: 1 addition & 1 deletion script/hassfest/quality_scale_validation/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _has_diagnostics_function(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration implements diagnostics."""

diagnostics_file = integration.path / "diagnostics.py"
Expand Down
2 changes: 1 addition & 1 deletion script/hassfest/quality_scale_validation/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _has_discovery_function(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration implements diagnostics."""

config_flow_file = integration.path / "config_flow.py"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _has_parallel_updates_defined(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration sets PARALLEL_UPDATES constant."""

errors = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _has_step_reauth_function(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has a reauthentication flow."""

config_flow_file = integration.path / "config_flow.py"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _has_step_reconfigure_function(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has a reconfiguration flow."""

config_flow_file = integration.path / "config_flow.py"
Expand Down
90 changes: 82 additions & 8 deletions script/hassfest/quality_scale_validation/runtime_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,31 @@
"""

import ast
import re

from homeassistant.const import Platform
from script.hassfest import ast_parse_module
from script.hassfest.model import Integration

_ANNOTATION_MATCH = re.compile(r"^[A-Za-z]+ConfigEntry$")
_FUNCTIONS: dict[str, dict[str, int]] = {
"__init__": { # based on ComponentProtocol
"async_migrate_entry": 2,
"async_remove_config_entry_device": 2,
"async_remove_entry": 2,
"async_setup_entry": 2,
"async_unload_entry": 2,
},
"diagnostics": { # based on DiagnosticsProtocol
"async_get_config_entry_diagnostics": 2,
"async_get_device_diagnostics": 2,
},
}
for platform in Platform: # based on EntityPlatformModule
_FUNCTIONS[platform.value] = {
"async_setup_entry": 2,
}


def _sets_runtime_data(
async_setup_entry_function: ast.AsyncFunctionDef, config_entry_argument: ast.arg
Expand All @@ -25,30 +46,83 @@ def _sets_runtime_data(
return False


def _get_setup_entry_function(module: ast.Module) -> ast.AsyncFunctionDef | None:
"""Get async_setup_entry function."""
def _get_async_function(module: ast.Module, name: str) -> ast.AsyncFunctionDef | None:
"""Get async function."""
for item in module.body:
if isinstance(item, ast.AsyncFunctionDef) and item.name == "async_setup_entry":
if isinstance(item, ast.AsyncFunctionDef) and item.name == name:
return item
return None


def validate(integration: Integration) -> list[str] | None:
def _check_function_annotation(
function: ast.AsyncFunctionDef, position: int
) -> str | None:
"""Ensure function uses CustomConfigEntry type annotation."""
if len(function.args.args) < position:
return f"{function.name} has incorrect signature"
argument = function.args.args[position - 1]
if not (
(annotation := argument.annotation)
and isinstance(annotation, ast.Name)
and _ANNOTATION_MATCH.match(annotation.id)
):
return f"([+ strict-typing]) {function.name} does not use typed ConfigEntry"
return None


def _check_typed_config_entry(integration: Integration) -> list[str]:
"""Ensure integration uses CustomConfigEntry type annotation."""
errors: list[str] = []
# Check body level function annotations
for file, functions in _FUNCTIONS.items():
module_file = integration.path / f"{file}.py"
if not module_file.exists():
continue
module = ast_parse_module(module_file)
for function, position in functions.items():
if not (async_function := _get_async_function(module, function)):
continue
if error := _check_function_annotation(async_function, position):
errors.append(f"{error} in {module_file}")

# Check config_flow annotations
config_flow_file = integration.path / "config_flow.py"
config_flow = ast_parse_module(config_flow_file)
for node in config_flow.body:
if not isinstance(node, ast.ClassDef):
continue
if any(
isinstance(async_function, ast.FunctionDef)
and async_function.name == "async_get_options_flow"
and (error := _check_function_annotation(async_function, 1))
for async_function in node.body
):
errors.append(f"{error} in {config_flow_file}")

return errors


def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate correct use of ConfigEntry.runtime_data."""
init_file = integration.path / "__init__.py"
init = ast_parse_module(init_file)

# Should not happen, but better to be safe
if not (async_setup_entry := _get_setup_entry_function(init)):
if not (async_setup_entry := _get_async_function(init, "async_setup_entry")):
return [f"Could not find `async_setup_entry` in {init_file}"]
if len(async_setup_entry.args.args) != 2:
return [f"async_setup_entry has incorrect signature in {init_file}"]
config_entry_argument = async_setup_entry.args.args[1]

errors: list[str] = []
if not _sets_runtime_data(async_setup_entry, config_entry_argument):
return [
errors.append(
"Integration does not set entry.runtime_data in async_setup_entry"
f"({init_file})"
]
)

return None
# Extra checks, if strict-typing is marked as done
if "strict-typing" in rules_done:
errors.extend(_check_typed_config_entry(integration))

return errors
2 changes: 1 addition & 1 deletion script/hassfest/quality_scale_validation/strict_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _strict_typing_components() -> set[str]:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration has strict typing enabled."""

if integration.domain not in _strict_typing_components():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _has_abort_unique_id_configured(module: ast.Module) -> bool:
)


def validate(integration: Integration) -> list[str] | None:
def validate(integration: Integration, *, rules_done: set[str]) -> list[str] | None:
"""Validate that the integration prevents duplicate devices."""

if integration.manifest.get("single_config_entry"):
Expand Down

0 comments on commit 95107cf

Please sign in to comment.