
Commit 6ebceca

add more tests
1 parent c40c52d commit 6ebceca


3 files changed: +137 −61 lines


python/pyspark/streaming/dstream.py

Lines changed: 4 additions & 4 deletions
@@ -286,7 +286,7 @@ def saveAsTextFiles(self, prefix, suffix=None):
         Save this DStream as a text file, using string representations of elements.
         """

-        def saveAsTextFile(rdd, time):
+        def saveAsTextFile(time, rdd):
             """
             Closure to save element in RDD in DStream as Pickled data in file.
             This closure is called by py4j callback server.
@@ -303,7 +303,7 @@ def saveAsPickleFiles(self, prefix, suffix=None):
         is 10.
         """

-        def saveAsPickleFile(rdd, time):
+        def saveAsPickleFile(time, rdd):
             """
             Closure to save element in RDD in the DStream as Pickled data in file.
             This closure is called by py4j callback server.
@@ -388,7 +388,7 @@ def leftOuterJoin(self, other, numPartitions=None):
         Hash partitioning is used to generate the RDDs with `numPartitions`
         partitions.
         """
-        return self.transformWith(lambda a, b: a.leftOuterJion(b, numPartitions), other)
+        return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)

     def rightOuterJoin(self, other, numPartitions=None):
         """
@@ -502,7 +502,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
         @param numPartitions  number of partitions of each RDD in the new DStream.
         """
         keyed = self.map(lambda x: (x, 1))
-        counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b,
+        counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
                                              windowDuration, slideDuration, numPartitions)
         return counted.filter(lambda (k, v): v > 0).count()

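Note (not part of this commit): the countByValueAndWindow change above only swaps the inline lambdas for operator.add/operator.sub as the reduce and inverse-reduce functions. A minimal pure-Python sketch of the add/subtract bookkeeping that such a windowed count relies on, using a hypothetical helper name and plain lists of batches instead of DStreams:

import operator
from collections import Counter

def windowed_counts(batches, window):
    # Keep per-value counts for the current window: add each batch as it
    # enters the window and subtract the batch that slides out of it.
    counts = Counter()
    results = []
    for i, batch in enumerate(batches):
        for v in batch:                      # batch entering the window
            counts[v] = operator.add(counts[v], 1)
        old = i - window
        if old >= 0:
            for v in batches[old]:           # batch leaving the window
                counts[v] = operator.sub(counts[v], 1)
        results.append({k: c for k, c in counts.items() if c > 0})
    return results

print(windowed_counts([[0], [0, 1], [0, 1, 2]], window=2))
# [{0: 1}, {0: 2, 1: 1}, {0: 2, 1: 2, 2: 1}]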
python/pyspark/streaming/tests.py

Lines changed: 118 additions & 38 deletions
@@ -15,17 +15,12 @@
 # limitations under the License.
 #

-"""
-Unit tests for Python SparkStreaming; additional tests are implemented as doctests in
-individual modules.
-
-Callback server is sometimes unstable sometimes, which cause error in test case.
-But this is very rare case.
-"""
+import os
 from itertools import chain
 import time
 import operator
 import unittest
+import tempfile

 from pyspark.context import SparkContext
 from pyspark.streaming.context import StreamingContext
@@ -45,16 +40,20 @@ def setUp(self):
     def tearDown(self):
         self.ssc.stop()

-    def _test_func(self, input, func, expected, sort=False):
+    def _test_func(self, input, func, expected, sort=False, input2=None):
         """
         @param input: dataset for the test. This should be list of lists.
         @param func: wrapped function. This function should return PythonDStream object.
         @param expected: expected output for this testcase.
         """
-        # Generate input stream with user-defined input.
         input_stream = self.ssc.queueStream(input)
+        input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
         # Apply test function to stream.
-        stream = func(input_stream)
+        if input2:
+            stream = func(input_stream, input_stream2)
+        else:
+            stream = func(input_stream)
+
         result = stream.collect()
         self.ssc.start()

@@ -92,7 +91,7 @@ def test_take(self):
     def test_first(self):
         input = [range(10)]
         dstream = self.ssc.queueStream(input)
-        self.assertEqual(0, dstream)
+        self.assertEqual(0, dstream.first())

     def test_map(self):
         """Basic operation test for DStream.map."""
@@ -238,55 +237,122 @@ def add(a, b):
                     [("a", "11"), ("b", "1"), ("", "111")]]
         self._test_func(input, func, expected, sort=True)

+    def test_repartition(self):
+        input = [range(1, 5), range(5, 9)]
+        rdds = [self.sc.parallelize(r, 2) for r in input]
+
+        def func(dstream):
+            return dstream.repartitions(1).glom()
+        expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
+        self._test_func(rdds, func, expected)
+
     def test_union(self):
-        input1 = [range(3), range(5), range(1), range(6)]
-        input2 = [range(3, 6), range(5, 6), range(1, 6)]
+        input1 = [range(3), range(5), range(6)]
+        input2 = [range(3, 6), range(5, 6)]

-        d1 = self.ssc.queueStream(input1)
-        d2 = self.ssc.queueStream(input2)
-        d = d1.union(d2)
-        result = d.collect()
-        expected = [range(6), range(6), range(6), range(6)]
+        def func(d1, d2):
+            return d1.union(d2)

-        self.ssc.start()
-        start_time = time.time()
-        # Loop until get the expected the number of the result from the stream.
-        while True:
-            current_time = time.time()
-            # Check time out.
-            if (current_time - start_time) > self.timeout * 2:
-                break
-            # StreamingContext.awaitTermination is not used to wait because
-            # if py4j server is called every 50 milliseconds, it gets an error.
-            time.sleep(0.05)
-            # Check if the output is the same length of expected output.
-            if len(expected) == len(result):
-                break
-        self.assertEqual(expected, result)
+        expected = [range(6), range(6), range(6)]
+        self._test_func(input1, func, expected, input2=input2)
+
+    def test_cogroup(self):
+        input = [[(1, 1), (2, 1), (3, 1)],
+                 [(1, 1), (1, 1), (1, 1), (2, 1)],
+                 [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
+        input2 = [[(1, 2)],
+                  [(4, 1)],
+                  [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
+
+        def func(d1, d2):
+            return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
+
+        expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
+                    [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
+                    [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
+        self._test_func(input, func, expected, sort=True, input2=input2)
+
+    def test_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.join(b)
+
+        expected = [[('b', (2, 3))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_left_outer_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.leftOuterJoin(b)
+
+        expected = [[('a', (1, None)), ('b', (2, 3))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_right_outer_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.rightOuterJoin(b)
+
+        expected = [[('b', (2, 3)), ('c', (None, 4))]]
+        self._test_func(input, func, expected, True, input2)
+
+    def test_full_outer_join(self):
+        input = [[('a', 1), ('b', 2)]]
+        input2 = [[('b', 3), ('c', 4)]]
+
+        def func(a, b):
+            return a.fullOuterJoin(b)
+
+        expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
+        self._test_func(input, func, expected, True, input2)


 class TestWindowFunctions(PySparkStreamingTestCase):

-    timeout = 15
+    timeout = 20
+
+    def test_window(self):
+        input = [range(1), range(2), range(3), range(4), range(5)]
+
+        def func(dstream):
+            return dstream.window(3, 1).count()
+
+        expected = [[1], [3], [6], [9], [12], [9], [5]]
+        self._test_func(input, func, expected)

     def test_count_by_window(self):
-        input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+        input = [range(1), range(2), range(3), range(4), range(5)]

         def func(dstream):
-            return dstream.countByWindow(4, 1)
+            return dstream.countByWindow(3, 1)

-        expected = [[1], [3], [6], [9], [12], [15], [11], [6]]
+        expected = [[1], [3], [6], [9], [12], [9], [5]]
         self._test_func(input, func, expected)

     def test_count_by_window_large(self):
         input = [range(1), range(2), range(3), range(4), range(5), range(6)]

         def func(dstream):
-            return dstream.countByWindow(6, 1)
+            return dstream.countByWindow(5, 1)

         expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
         self._test_func(input, func, expected)

+    def test_count_by_value_and_window(self):
+        input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+        def func(dstream):
+            return dstream.countByValueAndWindow(6, 1)
+
+        expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
+        self._test_func(input, func, expected)
+
     def test_group_by_key_and_window(self):
         input = [[('a', i)] for i in range(5)]

@@ -359,6 +425,20 @@ def test_queueStream(self):
         time.sleep(1)
         self.assertEqual(input, result[:3])

+    # TODO: test textFileStream
+    # def test_textFileStream(self):
+    #     input = [range(i) for i in range(3)]
+    #     dstream = self.ssc.queueStream(input)
+    #     d = os.path.join(tempfile.gettempdir(), str(id(self)))
+    #     if not os.path.exists(d):
+    #         os.makedirs(d)
+    #     dstream.saveAsTextFiles(os.path.join(d, 'test'))
+    #     dstream2 = self.ssc.textFileStream(d)
+    #     result = dstream2.collect()
+    #     self.ssc.start()
+    #     time.sleep(2)
+    #     self.assertEqual(input, result[:3])
+
     def test_union(self):
         input = [range(i) for i in range(3)]
         dstream = self.ssc.queueStream(input)

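Note (not part of this commit): the expected values in test_window and test_count_by_window above can be checked by hand. With batches of sizes 1 through 5, a window of 3 batch intervals and a slide of 1, the window keeps sliding after the last batch arrives, so the counts taper off at the end. A small sanity-check script under those assumptions:

# Sanity check for the expected output [1, 3, 6, 9, 12, 9, 5] used above.
batches = [list(range(n)) for n in range(1, 6)]   # batch sizes 1..5
window = 3
counts = []
for end in range(1, len(batches) + window):       # window slides past the end of the stream
    start = max(0, end - window)
    if start >= len(batches):
        break
    counts.append(sum(len(b) for b in batches[start:min(end, len(batches))]))
print(counts)  # [1, 3, 6, 9, 12, 9, 5]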
streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 15 additions & 19 deletions
@@ -93,7 +93,8 @@ private[spark] object PythonDStream {
   }

   // helper function for ssc.transform()
-  def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction)
+  def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]],
+                    pyfunc: PythonRDDFunction)
     :JavaDStream[Array[Byte]] = {
     val func = new RDDFunction(pyfunc)
     ssc.transform(jdsteams, func)
@@ -210,9 +211,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],

   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     val currentTime = validTime
-    val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration,
+    val current = new Interval(currentTime - windowDuration,
       currentTime)
-    val previousWindow = currentWindow - slideDuration
+    val previous = current - slideDuration

     //  _____________________________
     // |  previous window  _________|___________________
@@ -225,35 +226,30 @@
     //     old RDDs                  new RDDs
     //

-    // Get the RDD of the reduced value of the previous window
-    val previousWindowRDD = getOrCompute(previousWindow.endTime)
+    val previousRDD = getOrCompute(previous.endTime)

-    if (pinvReduceFunc != null && previousWindowRDD.isDefined
+    if (pinvReduceFunc != null && previousRDD.isDefined
         // for small window, reduce once will be better than twice
-        && windowDuration > slideDuration * 5) {
+        && windowDuration >= slideDuration * 5) {

       // subtract the values from old RDDs
-      val oldRDDs =
-        parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
-      val subbed = if (oldRDDs.size > 0) {
-        invReduceFunc(previousWindowRDD, Some(ssc.sc.union(oldRDDs)), validTime)
+      val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
+      val subtracted = if (oldRDDs.size > 0) {
+        invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
       } else {
-        previousWindowRDD
+        previousRDD
       }

       // add the RDDs of the reduced values in "new time steps"
-      val newRDDs =
-        parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration)
-
+      val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
       if (newRDDs.size > 0) {
-        reduceFunc(subbed, Some(ssc.sc.union(newRDDs)), validTime)
+        reduceFunc(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
       } else {
-        subbed
+        subtracted
       }
     } else {
       // Get the RDDs of the reduced values in current window
-      val currentRDDs =
-        parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration)
+      val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
       if (currentRDDs.size > 0) {
         reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime)
       } else {

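Note (not part of this commit): the compute() changes above adjust the interval arithmetic for the incremental path of PythonReducedWindowedDStream. When an inverse reduce function is available and the window is at least five times the slide, the new window value is built from the previous one by subtracting the batches that slid out and adding the batches that slid in, instead of re-reducing the whole window. A pure-Python model of that idea, with hypothetical names and plain lists standing in for RDDs:

import operator
from functools import reduce

def incremental_window(batches, window, add=operator.add, sub=operator.sub):
    # Maintain a sliding-window aggregate over `window` batches with slide 1.
    results = []
    prev = None
    for t in range(len(batches)):
        lo = t - window + 1
        if prev is None:
            # First window: reduce everything in the window from scratch.
            cur = reduce(add, [x for b in batches[max(0, lo):t + 1] for x in b])
        else:
            old = [x for b in batches[max(0, lo - 1):max(0, lo)] for x in b]  # slid out
            new = [x for b in batches[t:t + 1] for x in b]                    # slid in
            cur = reduce(sub, old, prev)   # subtract values that left the window
            cur = reduce(add, new, cur)    # add values that entered the window
        results.append(cur)
        prev = cur
    return results

print(incremental_window([[1], [2, 3], [4], [5]], window=3))
# [1, 6, 10, 14]  -- same as summing each 3-batch window directly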