diff --git a/.strict-typing b/.strict-typing index af4bd4a9cf482..035218d80240a 100644 --- a/.strict-typing +++ b/.strict-typing @@ -320,6 +320,7 @@ homeassistant.components.plugwise.* homeassistant.components.poolsense.* homeassistant.components.powerwall.* homeassistant.components.private_ble_device.* +homeassistant.components.prometheus.* homeassistant.components.proximity.* homeassistant.components.prusalink.* homeassistant.components.pure_energie.* diff --git a/homeassistant/components/prometheus/__init__.py b/homeassistant/components/prometheus/__init__.py index 308bbb599eaf2..e17ae1190a48c 100644 --- a/homeassistant/components/prometheus/__init__.py +++ b/homeassistant/components/prometheus/__init__.py @@ -1,10 +1,15 @@ """Support for Prometheus metrics export.""" +from __future__ import annotations + +from collections.abc import Callable from contextlib import suppress import logging import string +from typing import Any, TypeVar, cast from aiohttp import web import prometheus_client +from prometheus_client.metrics import MetricWrapperBase import voluptuous as vol from homeassistant import core as hacore @@ -40,15 +45,20 @@ STATE_UNKNOWN, UnitOfTemperature, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, State from homeassistant.helpers import entityfilter, state as state_helper import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED +from homeassistant.helpers.entity_registry import ( + EVENT_ENTITY_REGISTRY_UPDATED, + EventEntityRegistryUpdatedData, +) from homeassistant.helpers.entity_values import EntityValues -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.event import EventStateChangedData +from homeassistant.helpers.typing import ConfigType, EventType from homeassistant.util.dt import as_timestamp from homeassistant.util.unit_conversion import TemperatureConverter +_MetricBaseT = TypeVar("_MetricBaseT", bound=MetricWrapperBase) _LOGGER = logging.getLogger(__name__) API_ENDPOINT = "/api/prometheus" @@ -97,12 +107,12 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool: """Activate Prometheus component.""" hass.http.register_view(PrometheusView(config[DOMAIN][CONF_REQUIRES_AUTH])) - conf = config[DOMAIN] - entity_filter = conf[CONF_FILTER] - namespace = conf.get(CONF_PROM_NAMESPACE) + conf: dict[str, Any] = config[DOMAIN] + entity_filter: entityfilter.EntityFilter = conf[CONF_FILTER] + namespace: str = conf[CONF_PROM_NAMESPACE] climate_units = hass.config.units.temperature_unit - override_metric = conf.get(CONF_OVERRIDE_METRIC) - default_metric = conf.get(CONF_DEFAULT_METRIC) + override_metric: str | None = conf.get(CONF_OVERRIDE_METRIC) + default_metric: str | None = conf.get(CONF_DEFAULT_METRIC) component_config = EntityValues( conf[CONF_COMPONENT_CONFIG], conf[CONF_COMPONENT_CONFIG_DOMAIN], @@ -118,9 +128,10 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool: default_metric, ) - hass.bus.listen(EVENT_STATE_CHANGED, metrics.handle_state_changed_event) + hass.bus.listen(EVENT_STATE_CHANGED, metrics.handle_state_changed_event) # type: ignore[arg-type] hass.bus.listen( - EVENT_ENTITY_REGISTRY_UPDATED, metrics.handle_entity_registry_updated + EVENT_ENTITY_REGISTRY_UPDATED, + metrics.handle_entity_registry_updated, # type: ignore[arg-type] ) for state in hass.states.all(): @@ -135,19 +146,21 @@ class PrometheusMetrics: def __init__( self, - entity_filter, - namespace, - climate_units, - component_config, - override_metric, - default_metric, - ): + entity_filter: entityfilter.EntityFilter, + namespace: str, + climate_units: UnitOfTemperature, + component_config: EntityValues, + override_metric: str | None, + default_metric: str | None, + ) -> None: """Initialize Prometheus Metrics.""" self._component_config = component_config self._override_metric = override_metric self._default_metric = default_metric self._filter = entity_filter - self._sensor_metric_handlers = [ + self._sensor_metric_handlers: list[ + Callable[[State, str | None], str | None] + ] = [ self._sensor_override_component_metric, self._sensor_override_metric, self._sensor_timestamp_metric, @@ -160,10 +173,12 @@ def __init__( self.metrics_prefix = f"{namespace}_" else: self.metrics_prefix = "" - self._metrics = {} + self._metrics: dict[str, MetricWrapperBase] = {} self._climate_units = climate_units - def handle_state_changed_event(self, event): + def handle_state_changed_event( + self, event: EventType[EventStateChangedData] + ) -> None: """Handle new messages from the bus.""" if (state := event.data.get("new_state")) is None: return @@ -179,7 +194,7 @@ def handle_state_changed_event(self, event): self.handle_state(state) - def handle_state(self, state): + def handle_state(self, state: State) -> None: """Add/update a state in Prometheus.""" entity_id = state.entity_id _LOGGER.debug("Handling state update for %s", entity_id) @@ -212,20 +227,22 @@ def handle_state(self, state): ) last_updated_time_seconds.labels(**labels).set(state.last_updated.timestamp()) - def handle_entity_registry_updated(self, event): + def handle_entity_registry_updated( + self, event: EventType[EventEntityRegistryUpdatedData] + ) -> None: """Listen for deleted, disabled or renamed entities and remove them from the Prometheus Registry.""" - if (action := event.data.get("action")) in (None, "create"): + if event.data["action"] in (None, "create"): return entity_id = event.data.get("entity_id") _LOGGER.debug("Handling entity update for %s", entity_id) - metrics_entity_id = None + metrics_entity_id: str | None = None - if action == "remove": + if event.data["action"] == "remove": metrics_entity_id = entity_id - elif action == "update": - changes = event.data.get("changes") + elif event.data["action"] == "update": + changes = event.data["changes"] if "entity_id" in changes: metrics_entity_id = changes["entity_id"] @@ -235,10 +252,14 @@ def handle_entity_registry_updated(self, event): if metrics_entity_id: self._remove_labelsets(metrics_entity_id) - def _remove_labelsets(self, entity_id, friendly_name=None): + def _remove_labelsets( + self, entity_id: str, friendly_name: str | None = None + ) -> None: """Remove labelsets matching the given entity id from all metrics.""" for _, metric in self._metrics.items(): - for sample in metric.collect()[0].samples: + for sample in cast(list[prometheus_client.Metric], metric.collect())[ + 0 + ].samples: if sample.labels["entity"] == entity_id and ( not friendly_name or sample.labels["friendly_name"] == friendly_name ): @@ -250,7 +271,7 @@ def _remove_labelsets(self, entity_id, friendly_name=None): with suppress(KeyError): metric.remove(*sample.labels.values()) - def _handle_attributes(self, state): + def _handle_attributes(self, state: State) -> None: for key, value in state.attributes.items(): metric = self._metric( f"{state.domain}_attr_{key.lower()}", @@ -264,13 +285,19 @@ def _handle_attributes(self, state): except (ValueError, TypeError): pass - def _metric(self, metric, factory, documentation, extra_labels=None): + def _metric( + self, + metric: str, + factory: type[_MetricBaseT], + documentation: str, + extra_labels: list[str] | None = None, + ) -> _MetricBaseT: labels = ["entity", "friendly_name", "domain"] if extra_labels is not None: labels.extend(extra_labels) try: - return self._metrics[metric] + return cast(_MetricBaseT, self._metrics[metric]) except KeyError: full_metric_name = self._sanitize_metric_name( f"{self.metrics_prefix}{metric}" @@ -281,7 +308,7 @@ def _metric(self, metric, factory, documentation, extra_labels=None): labels, registry=prometheus_client.REGISTRY, ) - return self._metrics[metric] + return cast(_MetricBaseT, self._metrics[metric]) @staticmethod def _sanitize_metric_name(metric: str) -> str: @@ -298,7 +325,7 @@ def _sanitize_metric_name(metric: str) -> str: ) @staticmethod - def state_as_number(state): + def state_as_number(state: State) -> float: """Return a state casted to a float.""" try: if state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TIMESTAMP: @@ -311,14 +338,14 @@ def state_as_number(state): return value @staticmethod - def _labels(state): + def _labels(state: State) -> dict[str, Any]: return { "entity": state.entity_id, "domain": state.domain, "friendly_name": state.attributes.get(ATTR_FRIENDLY_NAME), } - def _battery(self, state): + def _battery(self, state: State) -> None: if (battery_level := state.attributes.get(ATTR_BATTERY_LEVEL)) is not None: metric = self._metric( "battery_level_percent", @@ -331,7 +358,7 @@ def _battery(self, state): except ValueError: pass - def _handle_binary_sensor(self, state): + def _handle_binary_sensor(self, state: State) -> None: metric = self._metric( "binary_sensor_state", prometheus_client.Gauge, @@ -340,7 +367,7 @@ def _handle_binary_sensor(self, state): value = self.state_as_number(state) metric.labels(**self._labels(state)).set(value) - def _handle_input_boolean(self, state): + def _handle_input_boolean(self, state: State) -> None: metric = self._metric( "input_boolean_state", prometheus_client.Gauge, @@ -349,7 +376,7 @@ def _handle_input_boolean(self, state): value = self.state_as_number(state) metric.labels(**self._labels(state)).set(value) - def _numeric_handler(self, state, domain, title): + def _numeric_handler(self, state: State, domain: str, title: str) -> None: if unit := self._unit_string(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)): metric = self._metric( f"{domain}_state_{unit}", @@ -374,13 +401,13 @@ def _numeric_handler(self, state, domain, title): ) metric.labels(**self._labels(state)).set(value) - def _handle_input_number(self, state): + def _handle_input_number(self, state: State) -> None: self._numeric_handler(state, "input_number", "input number") - def _handle_number(self, state): + def _handle_number(self, state: State) -> None: self._numeric_handler(state, "number", "number") - def _handle_device_tracker(self, state): + def _handle_device_tracker(self, state: State) -> None: metric = self._metric( "device_tracker_state", prometheus_client.Gauge, @@ -389,14 +416,14 @@ def _handle_device_tracker(self, state): value = self.state_as_number(state) metric.labels(**self._labels(state)).set(value) - def _handle_person(self, state): + def _handle_person(self, state: State) -> None: metric = self._metric( "person_state", prometheus_client.Gauge, "State of the person (0/1)" ) value = self.state_as_number(state) metric.labels(**self._labels(state)).set(value) - def _handle_cover(self, state): + def _handle_cover(self, state: State) -> None: metric = self._metric( "cover_state", prometheus_client.Gauge, @@ -428,7 +455,7 @@ def _handle_cover(self, state): ) tilt_position_metric.labels(**self._labels(state)).set(float(tilt_position)) - def _handle_light(self, state): + def _handle_light(self, state: State) -> None: metric = self._metric( "light_brightness_percent", prometheus_client.Gauge, @@ -446,14 +473,16 @@ def _handle_light(self, state): except ValueError: pass - def _handle_lock(self, state): + def _handle_lock(self, state: State) -> None: metric = self._metric( "lock_state", prometheus_client.Gauge, "State of the lock (0/1)" ) value = self.state_as_number(state) metric.labels(**self._labels(state)).set(value) - def _handle_climate_temp(self, state, attr, metric_name, metric_description): + def _handle_climate_temp( + self, state: State, attr: str, metric_name: str, metric_description: str + ) -> None: if (temp := state.attributes.get(attr)) is not None: if self._climate_units == UnitOfTemperature.FAHRENHEIT: temp = TemperatureConverter.convert( @@ -466,7 +495,7 @@ def _handle_climate_temp(self, state, attr, metric_name, metric_description): ) metric.labels(**self._labels(state)).set(temp) - def _handle_climate(self, state): + def _handle_climate(self, state: State) -> None: self._handle_climate_temp( state, ATTR_TEMPERATURE, @@ -518,7 +547,7 @@ def _handle_climate(self, state): float(mode == current_mode) ) - def _handle_humidifier(self, state): + def _handle_humidifier(self, state: State) -> None: humidifier_target_humidity_percent = state.attributes.get(ATTR_HUMIDITY) if humidifier_target_humidity_percent: metric = self._metric( @@ -553,7 +582,7 @@ def _handle_humidifier(self, state): float(mode == current_mode) ) - def _handle_sensor(self, state): + def _handle_sensor(self, state: State) -> None: unit = self._unit_string(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)) for metric_handler in self._sensor_metric_handlers: @@ -583,12 +612,12 @@ def _handle_sensor(self, state): self._battery(state) - def _sensor_default_metric(self, state, unit): + def _sensor_default_metric(self, state: State, unit: str | None) -> str | None: """Get default metric.""" return self._default_metric @staticmethod - def _sensor_attribute_metric(state, unit): + def _sensor_attribute_metric(state: State, unit: str | None) -> str | None: """Get metric based on device class attribute.""" metric = state.attributes.get(ATTR_DEVICE_CLASS) if metric is not None: @@ -596,25 +625,27 @@ def _sensor_attribute_metric(state, unit): return None @staticmethod - def _sensor_timestamp_metric(state, unit): + def _sensor_timestamp_metric(state: State, unit: str | None) -> str | None: """Get metric for timestamp sensors, which have no unit of measurement attribute.""" metric = state.attributes.get(ATTR_DEVICE_CLASS) if metric == SensorDeviceClass.TIMESTAMP: return f"sensor_{metric}_seconds" return None - def _sensor_override_metric(self, state, unit): + def _sensor_override_metric(self, state: State, unit: str | None) -> str | None: """Get metric from override in configuration.""" if self._override_metric: return self._override_metric return None - def _sensor_override_component_metric(self, state, unit): + def _sensor_override_component_metric( + self, state: State, unit: str | None + ) -> str | None: """Get metric from override in component confioguration.""" return self._component_config.get(state.entity_id).get(CONF_OVERRIDE_METRIC) @staticmethod - def _sensor_fallback_metric(state, unit): + def _sensor_fallback_metric(state: State, unit: str | None) -> str | None: """Get metric from fallback logic for compatibility.""" if unit in (None, ""): try: @@ -626,10 +657,10 @@ def _sensor_fallback_metric(state, unit): return f"sensor_unit_{unit}" @staticmethod - def _unit_string(unit): + def _unit_string(unit: str | None) -> str | None: """Get a formatted string of the unit.""" if unit is None: - return + return None units = { UnitOfTemperature.CELSIUS: "celsius", @@ -640,7 +671,7 @@ def _unit_string(unit): default = default.lower() return units.get(unit, default) - def _handle_switch(self, state): + def _handle_switch(self, state: State) -> None: metric = self._metric( "switch_state", prometheus_client.Gauge, "State of the switch (0/1)" ) @@ -653,10 +684,10 @@ def _handle_switch(self, state): self._handle_attributes(state) - def _handle_zwave(self, state): + def _handle_zwave(self, state: State) -> None: self._battery(state) - def _handle_automation(self, state): + def _handle_automation(self, state: State) -> None: metric = self._metric( "automation_triggered_count", prometheus_client.Counter, @@ -665,7 +696,7 @@ def _handle_automation(self, state): metric.labels(**self._labels(state)).inc() - def _handle_counter(self, state): + def _handle_counter(self, state: State) -> None: metric = self._metric( "counter_value", prometheus_client.Gauge, @@ -674,7 +705,7 @@ def _handle_counter(self, state): metric.labels(**self._labels(state)).set(self.state_as_number(state)) - def _handle_update(self, state): + def _handle_update(self, state: State) -> None: metric = self._metric( "update_state", prometheus_client.Gauge, @@ -694,7 +725,7 @@ def __init__(self, requires_auth: bool) -> None: """Initialize Prometheus view.""" self.requires_auth = requires_auth - async def get(self, request): + async def get(self, request: web.Request) -> web.Response: """Handle request for Prometheus metrics.""" _LOGGER.debug("Received Prometheus metrics request") diff --git a/mypy.ini b/mypy.ini index ce0b4a3575cc6..4b74ad0608f95 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2961,6 +2961,16 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.prometheus.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.proximity.*] check_untyped_defs = true disallow_incomplete_defs = true