Skip to content

Commit

Permalink
Add MQTT integration discovery (home-assistant#41332)
Browse files Browse the repository at this point in the history
* Add MQTT integration discovery

* Add script/hassfest/mqtt.py

* Unsubscribe if config entry exists

* Add homeassistant/generated/mqtt.py

* Fix bad loop

* Improve tests

* Improve tests

* Apply suggestions from code review

Co-authored-by: Fabian Affolter <mail@fabian-affolter.ch>

* Prevent initiating multiple config flows

Co-authored-by: Fabian Affolter <mail@fabian-affolter.ch>
  • Loading branch information
emontnemery and fabaff authored Oct 7, 2020
1 parent 3f263d5 commit 343e5d6
Show file tree
Hide file tree
Showing 21 changed files with 399 additions and 23 deletions.
3 changes: 1 addition & 2 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
DEFAULT_QOS,
DEFAULT_RETAIN,
DEFAULT_WILL,
DOMAIN,
MQTT_CONNECTED,
MQTT_DISCONNECTED,
PROTOCOL_311,
Expand All @@ -86,8 +87,6 @@

_LOGGER = logging.getLogger(__name__)

DOMAIN = "mqtt"

DATA_MQTT = "mqtt"

SERVICE_PUBLISH = "publish"
Expand Down
2 changes: 2 additions & 0 deletions homeassistant/components/mqtt/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
ATTR_RETAIN: DEFAULT_RETAIN,
}

DOMAIN = "mqtt"

MQTT_CONNECTED = "mqtt_connected"
MQTT_DISCONNECTED = "mqtt_disconnected"

Expand Down
60 changes: 57 additions & 3 deletions homeassistant/components/mqtt/discovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Support for MQTT discovery."""
import asyncio
import functools
import json
import logging
import re
Expand All @@ -9,9 +10,15 @@
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import async_get_mqtt

from .abbreviations import ABBREVIATIONS, DEVICE_ABBREVIATIONS
from .const import ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD,
ATTR_DISCOVERY_TOPIC,
DOMAIN,
)

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -39,7 +46,9 @@
ALREADY_DISCOVERED = "mqtt_discovered_components"
CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup"
DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock"
DATA_CONFIG_FLOW_LOCK = "mqtt_discovery_config_flow_lock"
DISCOVERY_UNSUBSCRIBE = "mqtt_discovery_unsubscribe"
INTEGRATION_UNSUBSCRIBE = "mqtt_integration_discovery_unsubscribe"
MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}"
MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}"
LAST_DISCOVERY = "mqtt_last_discovery"
Expand All @@ -65,8 +74,9 @@ async def async_start(
hass: HomeAssistantType, discovery_topic, config_entry=None
) -> bool:
"""Start MQTT Discovery."""
mqtt_integrations = {}

async def async_device_message_received(msg):
async def async_entity_message_received(msg):
"""Process the received message."""
hass.data[LAST_DISCOVERY] = time.time()
payload = msg.payload
Expand Down Expand Up @@ -172,12 +182,52 @@ async def async_device_message_received(msg):
)

hass.data[DATA_CONFIG_ENTRY_LOCK] = asyncio.Lock()
hass.data[DATA_CONFIG_FLOW_LOCK] = asyncio.Lock()
hass.data[CONFIG_ENTRY_IS_SETUP] = set()

hass.data[DISCOVERY_UNSUBSCRIBE] = await mqtt.async_subscribe(
hass, f"{discovery_topic}/#", async_device_message_received, 0
hass, f"{discovery_topic}/#", async_entity_message_received, 0
)
hass.data[LAST_DISCOVERY] = time.time()
mqtt_integrations = await async_get_mqtt(hass)

hass.data[INTEGRATION_UNSUBSCRIBE] = {}

for (integration, topics) in mqtt_integrations.items():

async def async_integration_message_received(integration, msg):
"""Process the received message."""
key = f"{integration}_{msg.subscribed_topic}"

# Lock to prevent initiating many parallel config flows.
# Note: The lock is not intended to prevent a race, only for performance
async with hass.data[DATA_CONFIG_FLOW_LOCK]:
# Already unsubscribed
if key not in hass.data[INTEGRATION_UNSUBSCRIBE]:
return

result = await hass.config_entries.flow.async_init(
integration, context={"source": DOMAIN}, data=msg
)
if (
result
and result["type"] == "abort"
and result["reason"]
in ["already_configured", "single_instance_allowed"]
):
unsub = hass.data[INTEGRATION_UNSUBSCRIBE].pop(key, None)
if unsub is None:
return
unsub()

for topic in topics:
key = f"{integration}_{topic}"
hass.data[INTEGRATION_UNSUBSCRIBE][key] = await mqtt.async_subscribe(
hass,
topic,
functools.partial(async_integration_message_received, integration),
0,
)

return True

Expand All @@ -187,3 +237,7 @@ async def async_stop(hass: HomeAssistantType) -> bool:
if DISCOVERY_UNSUBSCRIBE in hass.data and hass.data[DISCOVERY_UNSUBSCRIBE]:
hass.data[DISCOVERY_UNSUBSCRIBE]()
hass.data[DISCOVERY_UNSUBSCRIBE] = None
if INTEGRATION_UNSUBSCRIBE in hass.data:
for key, unsub in list(hass.data[INTEGRATION_UNSUBSCRIBE].items()):
unsub()
hass.data[INTEGRATION_UNSUBSCRIBE].pop(key)
62 changes: 49 additions & 13 deletions homeassistant/components/tasmota/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,72 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH

def __init__(self):
"""Initialize flow."""
self._prefix = DEFAULT_PREFIX

async def async_step_mqtt(self, discovery_info=None):
"""Handle a flow initialized by MQTT discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")

await self.async_set_unique_id(DOMAIN)

# Validate the topic, will throw if it fails
prefix = discovery_info.subscribed_topic
if prefix.endswith("/#"):
prefix = prefix[:-2]
try:
valid_subscribe_topic(f"{prefix}/#")
except vol.Invalid:
return self.async_abort(reason="invalid_discovery_info")

self._prefix = prefix

return await self.async_step_confirm()

async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")

return await self.async_step_config()
if self.show_advanced_options:
return await self.async_step_config()
return await self.async_step_confirm()

async def async_step_config(self, user_input=None):
"""Confirm the setup."""
errors = {}
data = {CONF_DISCOVERY_PREFIX: DEFAULT_PREFIX}
data = {CONF_DISCOVERY_PREFIX: self._prefix}

if user_input is not None:
bad_prefix = False
if self.show_advanced_options:
prefix = user_input[CONF_DISCOVERY_PREFIX]
try:
valid_subscribe_topic(f"{prefix}/#")
except vol.Invalid:
errors["base"] = "invalid_discovery_topic"
bad_prefix = True
else:
data = user_input
prefix = user_input[CONF_DISCOVERY_PREFIX]
if prefix.endswith("/#"):
prefix = prefix[:-2]
try:
valid_subscribe_topic(f"{prefix}/#")
except vol.Invalid:
errors["base"] = "invalid_discovery_topic"
bad_prefix = True
else:
data[CONF_DISCOVERY_PREFIX] = prefix
if not bad_prefix:
return self.async_create_entry(title="Tasmota", data=data)

fields = {}
if self.show_advanced_options:
fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX)] = str
fields[vol.Optional(CONF_DISCOVERY_PREFIX, default=self._prefix)] = str

return self.async_show_form(
step_id="config", data_schema=vol.Schema(fields), errors=errors
)

async def async_step_confirm(self, user_input=None):
"""Confirm the setup."""

data = {CONF_DISCOVERY_PREFIX: self._prefix}

if user_input is not None:
return self.async_create_entry(title="Tasmota", data=data)

return self.async_show_form(step_id="confirm")
1 change: 1 addition & 0 deletions homeassistant/components/tasmota/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
"documentation": "https://www.home-assistant.io/integrations/tasmota",
"requirements": ["hatasmota==0.0.10"],
"dependencies": ["mqtt"],
"mqtt": ["tasmota/discovery/#"],
"codeowners": ["@emontnemery"]
}
3 changes: 3 additions & 0 deletions homeassistant/components/tasmota/strings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
{
"config": {
"step": {
"confirm": {
"description": "Do you want to set up Tasmota?"
},
"config": {
"title": "Tasmota",
"description": "Please enter the Tasmota configuration.",
Expand Down
1 change: 1 addition & 0 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SOURCE_HOMEKIT = "homekit"
SOURCE_IMPORT = "import"
SOURCE_INTEGRATION_DISCOVERY = "integration_discovery"
SOURCE_MQTT = "mqtt"
SOURCE_SSDP = "ssdp"
SOURCE_USER = "user"
SOURCE_ZEROCONF = "zeroconf"
Expand Down
12 changes: 12 additions & 0 deletions homeassistant/generated/mqtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Automatically generated by hassfest.
To update, run python3 -m script.hassfest
"""

# fmt: off

MQTT = {
"tasmota": [
"tasmota/discovery/#"
]
}
1 change: 1 addition & 0 deletions homeassistant/helpers/config_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def async_step_discovery(

async_step_zeroconf = async_step_discovery
async_step_ssdp = async_step_discovery
async_step_mqtt = async_step_discovery
async_step_homekit = async_step_discovery

async def async_step_import(self, _: Optional[Dict[str, Any]]) -> Dict[str, Any]:
Expand Down
1 change: 1 addition & 0 deletions homeassistant/helpers/config_entry_oauth2_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ async def async_step_discovery(
return await self.async_step_pick_implementation()

async_step_user = async_step_pick_implementation
async_step_mqtt = async_step_discovery
async_step_ssdp = async_step_discovery
async_step_zeroconf = async_step_discovery
async_step_homekit = async_step_discovery
Expand Down
21 changes: 21 additions & 0 deletions homeassistant/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
cast,
)

from homeassistant.generated.mqtt import MQTT
from homeassistant.generated.ssdp import SSDP
from homeassistant.generated.zeroconf import HOMEKIT, ZEROCONF

Expand Down Expand Up @@ -202,6 +203,21 @@ async def async_get_ssdp(hass: "HomeAssistant") -> Dict[str, List]:
return ssdp


async def async_get_mqtt(hass: "HomeAssistant") -> Dict[str, List]:
"""Return cached list of MQTT mappings."""

mqtt: Dict[str, List] = MQTT.copy()

integrations = await async_get_custom_components(hass)
for integration in integrations.values():
if not integration.mqtt:
continue

mqtt[integration.domain] = integration.mqtt

return mqtt


class Integration:
"""An integration in Home Assistant."""

Expand Down Expand Up @@ -323,6 +339,11 @@ def quality_scale(self) -> Optional[str]:
"""Return Integration Quality Scale."""
return cast(str, self.manifest.get("quality_scale"))

@property
def mqtt(self) -> Optional[list]:
"""Return Integration MQTT entries."""
return cast(List[dict], self.manifest.get("mqtt"))

@property
def ssdp(self) -> Optional[list]:
"""Return Integration SSDP entries."""
Expand Down
1 change: 1 addition & 0 deletions homeassistant/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CONSTRAINT_FILE = "package_constraints.txt"
_LOGGER = logging.getLogger(__name__)
DISCOVERY_INTEGRATIONS: Dict[str, Iterable[str]] = {
"mqtt": ("mqtt",),
"ssdp": ("ssdp",),
"zeroconf": ("zeroconf", "homekit"),
}
Expand Down
2 changes: 2 additions & 0 deletions script/hassfest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
dependencies,
json,
manifest,
mqtt,
requirements,
services,
ssdp,
Expand All @@ -25,6 +26,7 @@
config_flow,
dependencies,
manifest,
mqtt,
services,
ssdp,
translations,
Expand Down
7 changes: 7 additions & 0 deletions script/hassfest/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def validate_integration(config: Config, integration: Integration):
"config_flow",
"HomeKit information in a manifest requires a config flow to exist",
)
if integration.manifest.get("mqtt"):
integration.add_error(
"config_flow",
"MQTT information in a manifest requires a config flow to exist",
)
if integration.manifest.get("ssdp"):
integration.add_error(
"config_flow",
Expand All @@ -51,6 +56,7 @@ def validate_integration(config: Config, integration: Integration):
"async_step_discovery" in config_flow
or "async_step_hassio" in config_flow
or "async_step_homekit" in config_flow
or "async_step_mqtt" in config_flow
or "async_step_ssdp" in config_flow
or "async_step_zeroconf" in config_flow
)
Expand Down Expand Up @@ -91,6 +97,7 @@ def generate_and_validate(integrations: Dict[str, Integration], config: Config):
if not (
integration.manifest.get("config_flow")
or integration.manifest.get("homekit")
or integration.manifest.get("mqtt")
or integration.manifest.get("ssdp")
or integration.manifest.get("zeroconf")
):
Expand Down
1 change: 1 addition & 0 deletions script/hassfest/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def documentation_url(value: str) -> str:
vol.Required("domain"): str,
vol.Required("name"): str,
vol.Optional("config_flow"): bool,
vol.Optional("mqtt"): [str],
vol.Optional("zeroconf"): [
vol.Any(
str,
Expand Down
Loading

0 comments on commit 343e5d6

Please sign in to comment.