Skip to content

Commit d07f454

Browse files
committed
Register StreamingListener before starting StreamingContext; revert unnecessary changes; fix the Python unit test
1 parent a6747cb commit d07f454

File tree

3 files changed

+49
-43
lines changed

3 files changed

+49
-43
lines changed

external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,42 @@ package org.apache.spark.streaming.mqtt
2020
import scala.concurrent.duration._
2121
import scala.language.postfixOps
2222

23-
import org.scalatest.BeforeAndAfterAll
23+
import org.scalatest.BeforeAndAfter
2424
import org.scalatest.concurrent.Eventually
2525

2626
import org.apache.spark.{SparkConf, SparkFunSuite}
2727
import org.apache.spark.storage.StorageLevel
2828
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
2929

30-
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
30+
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
31+
32+
private val batchDuration = Milliseconds(500)
33+
private val master = "local[2]"
34+
private val framework = this.getClass.getSimpleName
35+
private val topic = "def"
3136

32-
private val topic = "topic"
3337
private var ssc: StreamingContext = _
3438
private var MQTTTestUtils: MQTTTestUtils = _
3539

36-
override def beforeAll(): Unit = {
40+
before {
41+
ssc = new StreamingContext(master, framework, batchDuration)
3742
MQTTTestUtils = new MQTTTestUtils
3843
MQTTTestUtils.setup()
3944
}
4045

41-
override def afterAll(): Unit = {
46+
after {
4247
if (ssc != null) {
4348
ssc.stop()
4449
ssc = null
4550
}
46-
4751
if (MQTTTestUtils != null) {
4852
MQTTTestUtils.teardown()
4953
MQTTTestUtils = null
5054
}
5155
}
5256

5357
test("mqtt input stream") {
54-
val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
55-
ssc = new StreamingContext(sparkConf, Milliseconds(500))
5658
val sendMessage = "MQTT demo for spark streaming"
57-
5859
val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + MQTTTestUtils.brokerUri, topic,
5960
StorageLevel.MEMORY_ONLY)
6061

@@ -65,6 +66,9 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterA
6566
receiveMessage
6667
}
6768
}
69+
70+
MQTTTestUtils.registerStreamingListener(ssc)
71+
6872
ssc.start()
6973

7074
// wait for the receiver to start before publishing data, or we risk failing

external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
2222

2323
import scala.language.postfixOps
2424

25+
import com.google.common.base.Charsets.UTF_8
2526
import org.apache.activemq.broker.{BrokerService, TransportConnector}
2627
import org.apache.commons.lang3.RandomUtils
2728
import org.eclipse.paho.client.mqttv3._
@@ -46,6 +47,8 @@ private class MQTTTestUtils extends Logging {
4647
private var broker: BrokerService = _
4748
private var connector: TransportConnector = _
4849

50+
private var receiverStartedLatch = new CountDownLatch(1)
51+
4952
def brokerUri: String = {
5053
s"$brokerHost:$brokerPort"
5154
}
@@ -69,6 +72,8 @@ private class MQTTTestUtils extends Logging {
6972
connector.stop()
7073
connector = null
7174
}
75+
Utils.deleteRecursively(persistenceDir)
76+
receiverStartedLatch = null
7277
}
7378

7479
private def findFreePort(): Int = {
@@ -88,7 +93,7 @@ private class MQTTTestUtils extends Logging {
8893
client.connect()
8994
if (client.isConnected) {
9095
val msgTopic = client.getTopic(topic)
91-
val message = new MqttMessage(data.getBytes("utf-8"))
96+
val message = new MqttMessage(data.getBytes(UTF_8))
9297
message.setQos(1)
9398
message.setRetained(true)
9499

@@ -110,27 +115,37 @@ private class MQTTTestUtils extends Logging {
110115
}
111116

112117
/**
113-
* Block until at least one receiver has started or timeout occurs.
118+
* Call this one before starting StreamingContext so that we won't miss the
119+
* StreamingListenerReceiverStarted event.
114120
*/
115-
def waitForReceiverToStart(ssc: StreamingContext) : Unit = {
116-
val latch = new CountDownLatch(1)
121+
def registerStreamingListener(jssc: JavaStreamingContext): Unit = {
122+
registerStreamingListener(jssc.ssc)
123+
}
124+
125+
/**
126+
* Call this one before starting StreamingContext so that we won't miss the
127+
* StreamingListenerReceiverStarted event.
128+
*/
129+
def registerStreamingListener(ssc: StreamingContext): Unit = {
117130
ssc.addStreamingListener(new StreamingListener {
118131
override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
119-
latch.countDown()
132+
receiverStartedLatch.countDown()
120133
}
121134
})
122-
123-
assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
124135
}
125136

126-
def waitForReceiverToStart(jssc: JavaStreamingContext) : Unit = {
127-
val latch = new CountDownLatch(1)
128-
jssc.addStreamingListener(new StreamingListener {
129-
override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
130-
latch.countDown()
131-
}
132-
})
137+
/**
138+
* Block until at least one receiver has started or timeout occurs.
139+
*/
140+
def waitForReceiverToStart(jssc: JavaStreamingContext): Unit = {
141+
waitForReceiverToStart(jssc.ssc)
142+
}
133143

134-
assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
144+
/**
145+
* Block until at least one receiver has started or timeout occurs.
146+
*/
147+
def waitForReceiverToStart(ssc: StreamingContext): Unit = {
148+
assert(
149+
receiverStartedLatch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
135150
}
136151
}

python/pyspark/streaming/tests.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -863,31 +863,18 @@ def getOutput(_, rdd):
863863
self.ssc.start()
864864
return result
865865

866-
def _publishData(self, topic, data):
867-
start_time = time.time()
868-
while True:
869-
try:
870-
self._MQTTTestUtils.publishData(topic, data)
871-
break
872-
except:
873-
if time.time() - start_time < self.timeout:
874-
time.sleep(0.01)
875-
else:
876-
raise
877-
878-
def _validateStreamResult(self, sendData, result):
879-
receiveData = ''.join(result[0])
880-
self.assertEqual(sendData, receiveData)
881-
882866
def test_mqtt_stream(self):
883867
"""Test the Python MQTT stream API."""
884868
sendData = "MQTT demo for spark streaming"
885869
topic = self._randomTopic()
870+
self._MQTTTestUtils.registerStreamingListener(self.ssc._jssc)
886871
result = self._startContext(topic)
887872
self._MQTTTestUtils.waitForReceiverToStart(self.ssc._jssc)
888-
self._publishData(topic, sendData)
889-
self.wait_for(result, len(sendData))
890-
self._validateStreamResult(sendData, result)
873+
self._MQTTTestUtils.publishData(topic, sendData)
874+
self.wait_for(result, 1)
875+
# Because "publishData" sends duplicate messages, here we should use > 0
876+
self.assertTrue(len(result) > 0)
877+
self.assertEqual(sendData, result[0])
891878

892879

893880
def search_kafka_assembly_jar():

0 commit comments

Comments
 (0)