Commit 7f96294

added basic operation test cases

1 parent 3dda31a

5 files changed: +113 lines, −54 lines

examples/src/main/python/streaming/test_oprations.py (+10 −9)

@@ -9,22 +9,23 @@
 conf = SparkConf()
 conf.setAppName("PythonStreamingNetworkWordCount")
 ssc = StreamingContext(conf=conf, duration=Seconds(1))
-
-test_input = ssc._testInputStream([1,2,3])
-class buff:
+class Buff:
+    result = list()
     pass
+Buff.result = list()
+
+test_input = ssc._testInputStream([range(1,4), range(4,7), range(7,10)])
 
 fm_test = test_input.map(lambda x: (x, 1))
-fm_test.test_output(buff)
+fm_test.pyprint()
+fm_test._test_output(Buff.result)
 
 ssc.start()
 while True:
     ssc.awaitTermination(50)
-    try:
-        buff.result
+    if len(Buff.result) == 3:
         break
-    except AttributeError:
-        pass
 
 ssc.stop()
-print buff.result
+print Buff.result
+
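Note: the example above gets its results by polling a shared buffer rather than by registering a completion callback. The sketch below reproduces that polling pattern in plain Python with no Spark involved; fake_stream, the worker thread, and the sleep intervals are illustrative stand-ins, not part of this commit.

import threading
import time

class Buff(object):
    result = []

def fake_stream(batches, interval=0.1):
    """Append one batch's output to Buff.result per interval."""
    for batch in batches:
        time.sleep(interval)
        Buff.result.append([(x, 1) for x in batch])  # same shape as the map() above

worker = threading.Thread(
    target=fake_stream,
    args=([range(1, 4), range(4, 7), range(7, 10)],))
worker.start()

while True:
    time.sleep(0.05)              # stands in for ssc.awaitTermination(50)
    if len(Buff.result) == 3:     # stop once all three batches have arrived
        break

worker.join()
print(Buff.result)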

python/pyspark/streaming/context.py (+25 −18)

@@ -123,14 +123,14 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def stop(self, stopSparkContext=True):
+    def stop(self, stopSparkContext=True, stopGraceFully=False):
         """
         Stop the execution of the streams immediately (does not wait for all received data
         to be processed).
         """
 
         try:
-            self._jssc.stop(stopSparkContext)
+            self._jssc.stop(stopSparkContext, stopGraceFully)
         finally:
             # Stop Callback server
             SparkContext._gateway.shutdown()

@@ -141,27 +141,34 @@ def checkpoint(self, directory):
         """
         self._jssc.checkpoint(directory)
 
-    def _testInputStream(self, test_input, numSlices=None):
-
+    def _testInputStream(self, test_inputs, numSlices=None):
+        """
+        Generate multiple files to make a "stream" on the Scala side for tests.
+        Scala chooses one of the files and generates an RDD using PythonRDD.readRDDFromFile.
+        """
         numSlices = numSlices or self._sc.defaultParallelism
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
 
-        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-
-        # Make sure we distribute data evenly if it's smaller than self.batchSize
-        if "__len__" not in dir(test_input):
-            c = list(test_input)    # Make it a list so we can compute its length
-        batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
-        if batchSize > 1:
-            serializer = BatchedSerializer(self._sc._unbatched_serializer,
-                                           batchSize)
-        else:
-            serializer = self._sc._unbatched_serializer
-        serializer.dump_stream(test_input, tempFile)
-
+        tempFiles = list()
+        for test_input in test_inputs:
+            tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+
+            # Make sure we distribute data evenly if it's smaller than self.batchSize
+            if "__len__" not in dir(test_input):
+                c = list(test_input)    # Make it a list so we can compute its length
+            batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
+            if batchSize > 1:
+                serializer = BatchedSerializer(self._sc._unbatched_serializer,
+                                               batchSize)
+            else:
+                serializer = self._sc._unbatched_serializer
+            serializer.dump_stream(test_input, tempFile)
+            tempFiles.append(tempFile.name)
+
+        jtempFiles = ListConverter().convert(tempFiles, SparkContext._gateway._gateway_client)
         jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
-                                                        tempFile.name,
+                                                        jtempFiles,
                                                         numSlices).asJavaDStream()
         return DStream(jinput_stream, self, PickleSerializer())
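Note: the reworked _testInputStream writes each test batch to its own temporary file and hands the list of file names to the JVM-side PythonTestInputStream. The sketch below is a simplified stand-in for that per-batch serialization step, using pickle instead of PySpark's batched serializers; write_test_batches is a hypothetical helper, not code from this commit.

import pickle
from tempfile import NamedTemporaryFile

def write_test_batches(test_inputs):
    """Serialize each batch to its own temp file and return the file names."""
    temp_files = []
    for test_input in test_inputs:
        temp_file = NamedTemporaryFile(delete=False, suffix=".batch")
        pickle.dump(list(test_input), temp_file)
        temp_file.close()
        temp_files.append(temp_file.name)
    return temp_files

# One file per batch; the Scala side would then consume one file per interval.
print(write_test_batches([range(1, 4), range(4, 7), range(7, 10)]))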

python/pyspark/streaming/dstream.py (+2 −6)

@@ -217,7 +217,6 @@ def pyprint(self):
 
         """
         def takeAndPrint(rdd, time):
-            print "take and print ==================="
             taken = rdd.take(11)
             print "-------------------------------------------"
             print "Time: %s" % (str(time))

@@ -242,13 +241,10 @@ def _test_output(self, buff):
         Store data in dstream to buffer to verify the result in a test case
         """
         def get_output(rdd, time):
-            taken = rdd.take(11)
-            buff.result = taken
+            taken = rdd.collect()
+            buff.append(taken)
         self.foreachRDD(get_output)
 
-    def output(self):
-        self._jdstream.outputToFile()
-
 
 class PipelinedDStream(DStream):
     def __init__(self, prev, func, preservesPartitioning=False):
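Note: after this change _test_output collects every element of each RDD and appends it to a caller-supplied list via foreachRDD, instead of overwriting a single result attribute. A minimal plain-Python sketch of that append-per-batch callback pattern, assuming a FakeRDD stand-in and a fake_foreach_rdd loop (neither is a PySpark API):

class FakeRDD(object):
    """Stand-in for an RDD: collect() returns the batch's elements."""
    def __init__(self, data):
        self._data = list(data)

    def collect(self):
        return list(self._data)

def fake_foreach_rdd(batches, func):
    """Call func(rdd, time) once per batch, like DStream.foreachRDD would."""
    for t, batch in enumerate(batches):
        func(FakeRDD(batch), t)

result = []  # plays the role of the buff list passed in from the test

def get_output(rdd, time):
    # Same shape as the new get_output above: collect the batch, append it.
    result.append(rdd.collect())

fake_foreach_rdd([range(1, 4), range(4, 7)], get_output)
print(result)  # [[1, 2, 3], [4, 5, 6]]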

python/pyspark/streaming_tests.py (+76 −19)

@@ -35,76 +35,133 @@
 import time
 import unittest
 import zipfile
+import operator
 
+from pyspark.context import SparkContext
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.duration import *
 
 
 SPARK_HOME = os.environ["SPARK_HOME"]
 
-class buff:
+class StreamOutput:
     """
-    Buffer for store the output from stream
+    A class to store the output from a stream.
     """
-    result = None
+    result = list()
 
 class PySparkStreamingTestCase(unittest.TestCase):
     def setUp(self):
-        print "set up"
         class_name = self.__class__.__name__
         self.ssc = StreamingContext(appName=class_name, duration=Seconds(1))
 
     def tearDown(self):
-        print "tear donw"
-        self.ssc.stop()
-        time.sleep(10)
+        # Do not call StreamingContext.stop directly because we do not wait to shut down
+        # the callback server and the py4j client.
+        self.ssc._jssc.stop()
+        self.ssc._sc.stop()
+        # Why does it take a long time to terminate StreamingContext and SparkContext?
+        # Should we change the sleep time if this depends on machine spec?
+        time.sleep(5)
+
+    @classmethod
+    def tearDownClass(cls):
+        time.sleep(5)
+        SparkContext._gateway._shutdown_callback_server()
 
 class TestBasicOperationsSuite(PySparkStreamingTestCase):
+    """
+    Input and output of this TestBasicOperationsSuite are equivalent to the
+    Scala TestBasicOperationsSuite.
+    """
     def setUp(self):
         PySparkStreamingTestCase.setUp(self)
-        buff.result = None
+        StreamOutput.result = list()
         self.timeout = 10 # seconds
 
     def tearDown(self):
         PySparkStreamingTestCase.tearDown(self)
 
+    @classmethod
+    def tearDownClass(cls):
+        PySparkStreamingTestCase.tearDownClass()
+
     def test_map(self):
+        """Basic operation test for DStream.map"""
        test_input = [range(1,5), range(5,9), range(9, 13)]
        def test_func(dstream):
            return dstream.map(lambda x: str(x))
-        expected = map(str, test_input)
-        output = self.run_stream(test_input, test_func)
-        self.assertEqual(output, expected)
+        expected_output = map(lambda x: map(lambda y: str(y), x), test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
 
     def test_flatMap(self):
+        """Basic operation test for DStream.flatMap"""
        test_input = [range(1,5), range(5,9), range(9, 13)]
        def test_func(dstream):
            return dstream.flatMap(lambda x: (x, x * 2))
-        # Maybe there be good way to create flatmap
-        excepted = map(lambda x: list(chain.from_iterable((map(lambda y:[y, y*2], x)))),
+        expected_output = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
                        test_input)
-        output = self.run_stream(test_input, test_func)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_filter(self):
+        """Basic operation test for DStream.filter"""
+        test_input = [range(1,5), range(5,9), range(9, 13)]
+        def test_func(dstream):
+            return dstream.filter(lambda x: x % 2 == 0)
+        expected_output = map(lambda x: filter(lambda y: y % 2 == 0, x), test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_count(self):
+        """Basic operation test for DStream.count"""
+        test_input = [[], [1], range(1, 3), range(1,4), range(1,5)]
+        def test_func(dstream):
+            return dstream.count()
+        expected_output = map(lambda x: [len(x)], test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_reduce(self):
+        """Basic operation test for DStream.reduce"""
+        test_input = [range(1,5), range(5,9), range(9, 13)]
+        def test_func(dstream):
+            return dstream.reduce(operator.add)
+        expected_output = map(lambda x: [reduce(operator.add, x)], test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_reduceByKey(self):
+        """Basic operation test for DStream.reduceByKey"""
+        test_input = [["a", "a", "b"], ["", ""], []]
+        def test_func(dstream):
+            return dstream.map(lambda x: (x, 1)).reduceByKey(operator.add)
+        expected_output = [[("a", 2), ("b", 1)], [("", 2)], []]
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
 
-    def run_stream(self, test_input, test_func):
+    def _run_stream(self, test_input, test_func, expected_output):
+        """Start the stream and return the output."""
         # Generate input stream with user-defined input
         test_input_stream = self.ssc._testInputStream(test_input)
         # Apply test function to stream
         test_stream = test_func(test_input_stream)
         # Add job to get output from stream
-        test_stream._test_output(buff)
+        test_stream._test_output(StreamOutput.result)
         self.ssc.start()
 
         start_time = time.time()
+        # Loop until we get the result from the stream
         while True:
             current_time = time.time()
             # check time out
             if (current_time - start_time) > self.timeout:
-                self.ssc.stop()
                 break
             self.ssc.awaitTermination(50)
-            if buff.result is not None:
+            if len(expected_output) == len(StreamOutput.result):
                 break
-        return buff.result
+        return StreamOutput.result
 
 if __name__ == "__main__":
     unittest.main()
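Note: each test builds its expected_output per batch, mirroring the transformation applied to the DStream. The snippet below reproduces that construction without Spark; it uses explicit list() calls so it also runs under Python 3, whereas the tests above rely on Python 2's list-returning map/filter and built-in reduce.

from itertools import chain
from functools import reduce
import operator

# Each element of test_input is one batch of the stream.
test_input = [list(range(1, 5)), list(range(5, 9)), list(range(9, 13))]

# DStream.map(str): one list of strings per batch
expected_map = [[str(y) for y in x] for x in test_input]

# DStream.flatMap(lambda x: (x, x * 2)): each element expands to [y, y*2]
expected_flatmap = [list(chain.from_iterable([y, y * 2] for y in x)) for x in test_input]

# DStream.reduce(operator.add): a single-element list per batch
expected_reduce = [[reduce(operator.add, x)] for x in test_input]

print(expected_map[0])      # ['1', '2', '3', '4']
print(expected_flatmap[0])  # [1, 2, 2, 4, 3, 6, 4, 8]
print(expected_reduce[0])   # [10]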

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala (−2)

@@ -56,8 +56,6 @@ class PythonDStream[T: ClassTag](
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     parent.getOrCompute(validTime) match{
       case Some(rdd) =>
-        logInfo("RDD ID in python DStream ===========")
-        logInfo("RDD id " + rdd.id)
         val pythonRDD = new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator)
         Some(pythonRDD.asJavaRDD.rdd)
       case None => None
