Skip to content

Commit

Permalink
correct MQTT subscription filter (home-assistant#7269)
Browse files Browse the repository at this point in the history
* correct MQTT subscription filter

* wildcard handling (#) fixed

* wildcard handling (#) fixed

* added tests for topic subscription like +/something/#

* function names changed (line too long)

* using raw strings for regular expression
import order changed
  • Loading branch information
amigian74 authored and balloob committed May 2, 2017
1 parent 570c554 commit 0e08925
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
21 changes: 15 additions & 6 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import socket
import time
import ssl
import re
import requests.certs

import voluptuous as vol
Expand Down Expand Up @@ -639,12 +640,20 @@ def _raise_on_error(result):

def _match_topic(subscription, topic):
"""Test if topic matches subscription."""
reg_ex_parts = []
suffix = ""
if subscription.endswith('#'):
return (subscription[:-2] == topic or
topic.startswith(subscription[:-1]))

subscription = subscription[:-2]
suffix = "(.*)"
sub_parts = subscription.split('/')
topic_parts = topic.split('/')
for sub_part in sub_parts:
if sub_part == "+":
reg_ex_parts.append(r"([^\/]+)")
else:
reg_ex_parts.append(sub_part)

reg_ex = "^" + (r'\/'.join(reg_ex_parts)) + suffix + "$"

reg = re.compile(reg_ex)

return (len(sub_parts) == len(topic_parts) and
all(a == b for a, b in zip(sub_parts, topic_parts) if a != '+'))
return reg.match(topic) is not None
40 changes: 40 additions & 0 deletions tests/components/mqtt/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,46 @@ def test_subscribe_topic_subtree_wildcard_no_match(self):
self.hass.block_till_done()
self.assertEqual(0, len(self.calls))

def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)

fire_mqtt_message(self.hass, 'hi/test-topic', 'test-payload')

self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('hi/test-topic', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])

def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)

fire_mqtt_message(self.hass, 'hi/test-topic/here-iam', 'test-payload')

self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('hi/test-topic/here-iam', self.calls[0][0])
self.assertEqual('test-payload', self.calls[0][1])

def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)

fire_mqtt_message(self.hass, 'hi/here-iam/test-topic', 'test-payload')

self.hass.block_till_done()
self.assertEqual(0, len(self.calls))

def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self):
"""Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, '+/test-topic/#', self.record_calls)

fire_mqtt_message(self.hass, 'hi/another-test-topic', 'test-payload')

self.hass.block_till_done()
self.assertEqual(0, len(self.calls))

def test_subscribe_binary_topic(self):
"""Test the subscription to a binary topic."""
mqtt.subscribe(self.hass, 'test-topic', self.record_calls,
Expand Down

0 comments on commit 0e08925

Please sign in to comment.