Skip to content

Commit

Permalink
Rework MQTT config merging and adding defaults (home-assistant#90529)
Browse files Browse the repository at this point in the history
* Cleanup config merging and adding defaults

* Optimize and update tests

* Do not mix entry and yaml config

* Make sure hass.data is initilized

* remove check on get_mqtt_data

* Tweaks to MQTT client

* Remove None assigment mqtt client and fix mock
  • Loading branch information
jbouwh authored Apr 4, 2023
1 parent 690a0f3 commit 4a0d3e8
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 180 deletions.
111 changes: 22 additions & 89 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SERVICE_RELOAD,
)
from homeassistant.core import HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.exceptions import ConfigEntryError, TemplateError, Unauthorized
from homeassistant.helpers import config_validation as cv, event, template
from homeassistant.helpers.device_registry import DeviceEntry
from homeassistant.helpers.dispatcher import async_dispatcher_connect
Expand All @@ -45,11 +45,7 @@
publish,
subscribe,
)
from .config_integration import (
CONFIG_SCHEMA_ENTRY,
DEFAULT_VALUES,
PLATFORM_CONFIG_SCHEMA_BASE,
)
from .config_integration import CONFIG_SCHEMA_ENTRY, PLATFORM_CONFIG_SCHEMA_BASE
from .const import ( # noqa: F401
ATTR_PAYLOAD,
ATTR_QOS,
Expand Down Expand Up @@ -83,6 +79,7 @@
)
from .models import ( # noqa: F401
MqttCommandTemplate,
MqttData,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
Expand All @@ -102,8 +99,6 @@
SERVICE_PUBLISH = "publish"
SERVICE_DUMP = "dump"

MANDATORY_DEFAULT_VALUES = (CONF_PORT, CONF_DISCOVERY_PREFIX)

ATTR_TOPIC_TEMPLATE = "topic_template"
ATTR_PAYLOAD_TEMPLATE = "payload_template"

Expand Down Expand Up @@ -193,50 +188,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True


def _filter_entry_config(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove unknown keys from config entry data.
Extra keys may have been added when importing MQTT yaml configuration.
"""
filtered_data = {
k: entry.data[k] for k in CONFIG_ENTRY_CONFIG_KEYS if k in entry.data
}
if entry.data.keys() != filtered_data.keys():
_LOGGER.warning(
(
"The following unsupported configuration options were removed from the "
"MQTT config entry: %s"
),
entry.data.keys() - filtered_data.keys(),
)
hass.config_entries.async_update_entry(entry, data=filtered_data)


async def _async_auto_mend_config(
hass: HomeAssistant, entry: ConfigEntry, yaml_config: dict[str, Any]
) -> None:
"""Mends config fetched from config entry and adds missing values.
This mends incomplete migration from old version of HA Core.
"""
entry_updated = False
entry_config = {**entry.data}
for key in MANDATORY_DEFAULT_VALUES:
if key not in entry_config:
entry_config[key] = DEFAULT_VALUES[key]
entry_updated = True

if entry_updated:
hass.config_entries.async_update_entry(entry, data=entry_config)


def _merge_extended_config(entry: ConfigEntry, conf: ConfigType) -> dict[str, Any]:
"""Merge advanced options in configuration.yaml config with config entry."""
# Add default values
conf = {**DEFAULT_VALUES, **conf}
return {**conf, **entry.data}


async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle signals of config entry being updated.
Expand All @@ -245,45 +196,29 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -
await hass.config_entries.async_reload(entry.entry_id)


async def async_fetch_config(
hass: HomeAssistant, entry: ConfigEntry
) -> dict[str, Any] | None:
"""Fetch fresh MQTT yaml config from the hass config."""
mqtt_data = get_mqtt_data(hass)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_data.config = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))

# Remove unknown keys from config entry data
_filter_entry_config(hass, entry)

# Add missing defaults to migrate older config entries
await _async_auto_mend_config(hass, entry, mqtt_data.config or {})
# Bail out if broker setting is missing
if CONF_BROKER not in entry.data:
_LOGGER.error("MQTT broker is not configured, please configure it")
return None

# If user doesn't have configuration.yaml config, generate default values
# for options not in config entry data
if (conf := mqtt_data.config) is None:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))

# Merge advanced configuration values from configuration.yaml
conf = _merge_extended_config(entry, conf)
return conf


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
mqtt_data = get_mqtt_data(hass, True)
# validate entry config
try:
conf = CONFIG_SCHEMA_ENTRY(dict(entry.data))
except vol.MultipleInvalid as ex:
raise ConfigEntryError(
f"The MQTT config entry is invalid, please correct it: {ex}"
) from ex

# Fetch configuration and add missing defaults for basic options
if (conf := await async_fetch_config(hass, entry)) is None:
# Bail out
return False
# Fetch configuration and add default values
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
client.start(mqtt_data)

await async_create_certificate_temp_files(hass, dict(entry.data))
mqtt_data.client = MQTT(hass, entry, conf)
# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
Expand Down Expand Up @@ -349,7 +284,7 @@ async def async_publish_service(call: ServiceCall) -> None:
)
return

assert mqtt_data.client is not None and msg_topic is not None
assert msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)

hass.services.async_register(
Expand Down Expand Up @@ -585,7 +520,6 @@ def unsubscribe() -> None:
def is_connected(hass: HomeAssistant) -> bool:
"""Return if MQTT client is connected."""
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
return mqtt_data.client.connected


Expand All @@ -603,7 +537,6 @@ async def async_remove_config_entry_device(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload MQTT dump and publish service when the config entry is unloaded."""
mqtt_data = get_mqtt_data(hass)
assert mqtt_data.client is not None
mqtt_client = mqtt_data.client

# Unload publish and dump services.
Expand Down
28 changes: 16 additions & 12 deletions homeassistant/components/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from .models import (
AsyncMessageCallbackType,
MessageCallbackType,
MqttData,
PublishMessage,
PublishPayloadType,
ReceiveMessage,
Expand Down Expand Up @@ -111,11 +112,11 @@ async def async_publish(
encoding: str | None = DEFAULT_ENCODING,
) -> None:
"""Publish message to a MQTT topic."""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot publish to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
outgoing_payload = payload
if not isinstance(payload, bytes):
if not encoding:
Expand Down Expand Up @@ -161,11 +162,11 @@ async def async_subscribe(
Call the return value to unsubscribe.
"""
mqtt_data = get_mqtt_data(hass, True)
if mqtt_data.client is None or not mqtt_config_entry_enabled(hass):
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled"
)
mqtt_data = get_mqtt_data(hass)
# Support for a deprecated callback type was removed with HA core 2023.3.0
# The signature validation code can be removed from HA core 2023.5.0
non_default = 0
Expand Down Expand Up @@ -377,19 +378,16 @@ class MQTT:

_mqttc: mqtt.Client
_last_subscribe: float
_mqtt_data: MqttData

def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
conf: ConfigType,
self, hass: HomeAssistant, config_entry: ConfigEntry, conf: ConfigType
) -> None:
"""Initialize Home Assistant MQTT client."""
self._mqtt_data = get_mqtt_data(hass)

self.hass = hass
self.config_entry = config_entry
self.conf = conf

self._simple_subscriptions: dict[str, list[Subscription]] = {}
self._wildcard_subscriptions: list[Subscription] = []
self.connected = False
Expand All @@ -415,8 +413,6 @@ def ha_started(_: Event) -> None:

self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started)

self.init_client()

async def async_stop_mqtt(_event: Event) -> None:
"""Stop MQTT component."""
await self.async_disconnect()
Expand All @@ -425,6 +421,14 @@ async def async_stop_mqtt(_event: Event) -> None:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
)

def start(
self,
mqtt_data: MqttData,
) -> None:
"""Start Home Assistant MQTT client."""
self._mqtt_data = mqtt_data
self.init_client()

@property
def subscriptions(self) -> list[Subscription]:
"""Return the tracked subscriptions."""
Expand Down
59 changes: 14 additions & 45 deletions homeassistant/components/mqtt/config_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@

DEFAULT_TLS_PROTOCOL = "auto"

DEFAULT_VALUES = {
CONF_BIRTH_MESSAGE: DEFAULT_BIRTH,
CONF_DISCOVERY: DEFAULT_DISCOVERY,
CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX,
CONF_PORT: DEFAULT_PORT,
CONF_PROTOCOL: DEFAULT_PROTOCOL,
CONF_TRANSPORT: DEFAULT_TRANSPORT,
CONF_WILL_MESSAGE: DEFAULT_WILL,
CONF_KEEPALIVE: DEFAULT_KEEPALIVE,
}

PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema(
{
Platform.ALARM_CONTROL_PANEL.value: vol.All(
Expand Down Expand Up @@ -169,9 +158,11 @@
CONFIG_SCHEMA_ENTRY = vol.Schema(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All(
vol.Coerce(int), vol.Range(min=15)
),
vol.Required(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): str,
Expand All @@ -180,13 +171,17 @@
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): str,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All(
cv.string, vol.In(SUPPORTED_PROTOCOLS)
),
vol.Optional(CONF_WILL_MESSAGE, default=DEFAULT_WILL): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE, default=DEFAULT_BIRTH): valid_birth_will,
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
vol.Optional(
CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX
): valid_publish_topic,
vol.Optional(CONF_TRANSPORT, default=DEFAULT_TRANSPORT): vol.All(
cv.string, vol.In([TRANSPORT_TCP, TRANSPORT_WEBSOCKETS])
),
Expand All @@ -195,32 +190,6 @@
}
)

CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend(
{
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_KEEPALIVE): vol.All(vol.Coerce(int), vol.Range(min=15)),
vol.Optional(CONF_BROKER): cv.string,
vol.Optional(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile),
vol.Inclusive(
CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Inclusive(
CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG
): cv.isfile,
vol.Optional(CONF_TLS_INSECURE): cv.boolean,
vol.Optional(CONF_PROTOCOL): vol.All(cv.string, vol.In(SUPPORTED_PROTOCOLS)),
vol.Optional(CONF_WILL_MESSAGE): valid_birth_will,
vol.Optional(CONF_BIRTH_MESSAGE): valid_birth_will,
vol.Optional(CONF_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX): valid_publish_topic,
}
)

DEPRECATED_CONFIG_KEYS = [
CONF_BIRTH_MESSAGE,
CONF_BROKER,
Expand Down
1 change: 0 additions & 1 deletion homeassistant/components/mqtt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,6 @@ async def async_will_remove_from_hass(self) -> None:
def available(self) -> bool:
"""Return if the device is available."""
mqtt_data = get_mqtt_data(self.hass)
assert mqtt_data.client is not None
client = mqtt_data.client
if not client.connected and not self.hass.is_stopping:
return False
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/mqtt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def write_state_request(self, entity: Entity) -> None:
class MqttData:
"""Keep the MQTT entry data."""

client: MQTT | None = None
config: ConfigType | None = None
client: MQTT
config: ConfigType
debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict)
debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field(
default_factory=dict
Expand Down
5 changes: 1 addition & 4 deletions homeassistant/components/mqtt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,9 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
return config


def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
def get_mqtt_data(hass: HomeAssistant) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData
if ensure_exists:
mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
return mqtt_data
mqtt_data = hass.data[DATA_MQTT]
return mqtt_data

Expand Down
Loading

0 comments on commit 4a0d3e8

Please sign in to comment.