Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make MQTT reconnection logic more resilient and fix race condition #10133

Merged
merged 1 commit into from
Nov 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 12 additions & 22 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,8 @@ def __init__(self, hass, broker, port, client_id, keepalive, username,
self.broker = broker
self.port = port
self.keepalive = keepalive
self.topics = {}
self.wanted_topics = {}
self.subscribed_topics = {}
self.progress = {}
self.birth_message = birth_message
self._mqttc = None
Expand Down Expand Up @@ -526,22 +527,22 @@ def async_subscribe(self, topic, qos):
raise HomeAssistantError("topic need to be a string!")

with (yield from self._paho_lock):
if topic in self.topics:
if topic in self.subscribed_topics:
return

self.wanted_topics[topic] = qos
result, mid = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos)

_raise_on_error(result)
self.progress[mid] = topic
self.topics[topic] = None

@asyncio.coroutine
def async_unsubscribe(self, topic):
"""Unsubscribe from topic.

This method is a coroutine.
"""
self.wanted_topics.pop(topic, None)
result, mid = yield from self.hass.async_add_job(
self._mqttc.unsubscribe, topic)

Expand All @@ -562,15 +563,10 @@ def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code):
self._mqttc.disconnect()
return

old_topics = self.topics

self.topics = {key: value for key, value in self.topics.items()
if value is None}

for topic, qos in old_topics.items():
# qos is None if we were in process of subscribing
if qos is not None:
self.hass.add_job(self.async_subscribe, topic, qos)
self.progress = {}
self.subscribed_topics = {}
for topic, qos in self.wanted_topics.items():
self.hass.add_job(self.async_subscribe, topic, qos)

if self.birth_message:
self.hass.add_job(self.async_publish(
Expand All @@ -584,7 +580,7 @@ def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
topic = self.progress.pop(mid, None)
if topic is None:
return
self.topics[topic] = granted_qos[0]
self.subscribed_topics[topic] = granted_qos[0]

def _mqtt_on_message(self, _mqttc, _userdata, msg):
"""Message received callback."""
Expand All @@ -598,18 +594,12 @@ def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos):
topic = self.progress.pop(mid, None)
if topic is None:
return
self.topics.pop(topic, None)
self.subscribed_topics.pop(topic, None)

def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
"""Disconnected callback."""
self.progress = {}
self.topics = {key: value for key, value in self.topics.items()
if value is not None}

# Remove None values from topic list
for key in list(self.topics):
if self.topics[key] is None:
self.topics.pop(key)
self.subscribed_topics = {}

# When disconnected because of calling disconnect()
if result_code == 0:
Expand Down
28 changes: 18 additions & 10 deletions tests/components/mqtt/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,12 @@ def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
@mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
"""Test the re-connect tries."""
self.hass.data['mqtt'].topics = {
self.hass.data['mqtt'].subscribed_topics = {
'test/topic': 1,
'test/progress': None
}
self.hass.data['mqtt'].wanted_topics = {
'test/progress': 0,
'test/topic': 2,
}
self.hass.data['mqtt'].progress = {
1: 'test/progress'
Expand All @@ -403,7 +406,9 @@ def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls])

self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics)
self.assertEqual({'test/topic': 2, 'test/progress': 0},
self.hass.data['mqtt'].wanted_topics)
self.assertEqual({}, self.hass.data['mqtt'].subscribed_topics)
self.assertEqual({}, self.hass.data['mqtt'].progress)

def test_invalid_mqtt_topics(self):
Expand Down Expand Up @@ -556,12 +561,15 @@ def test_mqtt_subscribes_topics_on_connect(hass):
"""Test subscription to topic on connect."""
mqtt_client = yield from mock_mqtt_client(hass)

prev_topics = OrderedDict()
prev_topics['topic/test'] = 1,
prev_topics['home/sensor'] = 2,
prev_topics['still/pending'] = None
subscribed_topics = OrderedDict()
subscribed_topics['topic/test'] = 1
subscribed_topics['home/sensor'] = 2

wanted_topics = subscribed_topics.copy()
wanted_topics['still/pending'] = 0

hass.data['mqtt'].topics = prev_topics
hass.data['mqtt'].wanted_topics = wanted_topics
hass.data['mqtt'].subscribed_topics = subscribed_topics
hass.data['mqtt'].progress = {1: 'still/pending'}

# Return values for subscribe calls (rc, mid)
Expand All @@ -574,7 +582,7 @@ def test_mqtt_subscribes_topics_on_connect(hass):

assert not mqtt_client.disconnect.called

expected = [(topic, qos) for topic, qos in prev_topics.items()
if qos is not None]
expected = [(topic, qos) for topic, qos in wanted_topics.items()]

assert [call[1][1:] for call in hass.add_job.mock_calls] == expected
assert hass.data['mqtt'].progress == {}