Commit 270a9e1

added mapValues and flatMapValues; WIP for glom and mapPartitions test

1 parent bb10956 commit 270a9e1

File tree

3 files changed: +101 −18 lines changed

python/pyspark/streaming/context.py

Lines changed: 2 additions & 0 deletions

@@ -140,6 +140,8 @@ def _testInputStream(self, test_inputs, numSlices=None):
         """
         Generate multiple files to make "stream" in Scala side for test.
         Scala chooses one of the files and generates RDD using PythonRDD.readRDDFromFile.
+
+        QueueStream may be a good way to implement this function.
         """
         numSlices = numSlices or self._sc.defaultParallelism
         # Calling the Java parallelize() method with an ArrayList is too slow,
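
The "QueueStream" note above refers to queue-based input streams (Scala's StreamingContext.queueStream), which would avoid writing temporary files. Below is a rough sketch of what a queue-backed test input could look like; it assumes a PySpark build that exposes queueStream, which this commit does not yet have:

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext(appName="queue-stream-sketch")
ssc = StreamingContext(sc, 1)  # 1-second batches

# One RDD per batch, mirroring what _testInputStream builds from files.
# queueStream is assumed here; it is not part of this commit.
rdd_queue = [sc.parallelize(batch) for batch in [[1, 2, 3], [4, 5], []]]
stream = ssc.queueStream(rdd_queue, oneAtATime=True)
stream.pprint()

ssc.start()
ssc.awaitTermination(timeout=5)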

python/pyspark/streaming/dstream.py

Lines changed: 53 additions & 16 deletions

@@ -35,25 +35,31 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self.ctx = ssc._sc
         self._jrdd_deserializer = jrdd_deserializer

+    def context(self):
+        """
+        Return the StreamingContext associated with this DStream.
+        """
+        return self._ssc
+
     def count(self):
         """
         Return a new DStream which contains the number of elements in this DStream.
         """
-        return self._mapPartitions(lambda i: [sum(1 for _ in i)])._sum()
+        return self.mapPartitions(lambda i: [sum(1 for _ in i)])._sum()

     def _sum(self):
         """
         Add up the elements in this DStream.
         """
-        return self._mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)

     def print_(self, label=None):
         """
         Since print is reserved name for python, we cannot define a "print" method function.
         This function prints serialized data in RDD in DStream because Scala and Java cannot
-        deserialized pickled python object. Please use DStream.pyprint() instead to print results.
+        deserialize pickled python objects. Please use DStream.pyprint() to print results.

-        Call DStream.print().
+        Call DStream.print() and this function will print the byte arrays in the DStream.
         """
         # a hack to call print function in DStream
         getattr(self._jdstream, "print")(label)
@@ -63,29 +69,32 @@ def filter(self, f):
         Return a new DStream containing only the elements that satisfy predicate.
         """
         def func(iterator): return ifilter(f, iterator)
-        return self._mapPartitions(func)
+        return self.mapPartitions(func)

     def flatMap(self, f, preservesPartitioning=False):
         """
         Pass each value in the key-value pair DStream through flatMap function
         without changing the keys: this also retains the original RDD's partition.
         """
-        def func(s, iterator): return chain.from_iterable(imap(f, iterator))
+        def func(s, iterator):
+            return chain.from_iterable(imap(f, iterator))
         return self._mapPartitionsWithIndex(func, preservesPartitioning)

-    def map(self, f):
+    def map(self, f, preservesPartitioning=False):
         """
         Return a new DStream by applying a function to each element of DStream.
         """
-        def func(iterator): return imap(f, iterator)
-        return self._mapPartitions(func)
+        def func(iterator):
+            return imap(f, iterator)
+        return self.mapPartitions(func, preservesPartitioning)

-    def _mapPartitions(self, f):
+    def mapPartitions(self, f, preservesPartitioning=False):
         """
         Return a new DStream by applying a function to each partition of this DStream.
         """
-        def func(s, iterator): return f(iterator)
-        return self._mapPartitionsWithIndex(func)
+        def func(s, iterator):
+            return f(iterator)
+        return self._mapPartitionsWithIndex(func, preservesPartitioning)

     def _mapPartitionsWithIndex(self, f, preservesPartitioning=False):
         """
@@ -131,7 +140,7 @@ def combineLocally(iterator):
                 else:
                     combiners[k] = mergeValue(combiners[k], v)
             return combiners.iteritems()
-        locally_combined = self._mapPartitions(combineLocally)
+        locally_combined = self.mapPartitions(combineLocally)
         shuffled = locally_combined.partitionBy(numPartitions)

         def _mergeCombiners(iterator):
@@ -143,7 +152,7 @@ def _mergeCombiners(iterator):
                 combiners[k] = mergeCombiners(combiners[k], v)
             return combiners.iteritems()

-        return shuffled._mapPartitions(_mergeCombiners)
+        return shuffled.mapPartitions(_mergeCombiners)

     def partitionBy(self, numPartitions, partitionFunc=None):
         """
@@ -246,6 +255,34 @@ def takeAndPrint(rdd, time):

         self.foreachRDD(takeAndPrint)

+    def mapValues(self, f):
+        """
+        Pass each value in the key-value pair DStream through a map function
+        without changing the keys; this also retains the original RDD's
+        partitioning.
+        """
+        map_values_fn = lambda (k, v): (k, f(v))
+        return self.map(map_values_fn, preservesPartitioning=True)
+
+    def flatMapValues(self, f):
+        """
+        Pass each value in the key-value pair DStream through a flatMap function
+        without changing the keys; this also retains the original RDD's
+        partitioning.
+        """
+        flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+        return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+    def glom(self):
+        """
+        Return a new DStream in which each RDD is generated by applying glom() to each RDD of
+        this DStream. Applying glom() to an RDD coalesces all elements within each partition into
+        a list.
+        """
+        def func(iterator):
+            yield list(iterator)
+        return self.mapPartitions(func)
+
     #def transform(self, func): - TD
     #    from utils import RDDFunction
     #    wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
@@ -255,7 +292,7 @@ def takeAndPrint(rdd, time):
     def _test_output(self, result):
         """
         This function is only for test case.
-        Store data in a DStream to result to verify the result in tese case
+        Store data in a DStream to result to verify the result in test case
         """
         def get_output(rdd, time):
             taken = rdd.collect()
@@ -318,4 +355,4 @@ def _jdstream(self):
         return self._jdstream_val

     def _is_pipelinable(self):
-        return not (self.is_cached)
+        return not self.is_cached
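
Taken together, the methods added above mirror the corresponding RDD API. A minimal usage sketch (not part of the commit; the variable names and stream contents are hypothetical, and _testInputStream is the test-only helper from context.py):

import operator

words = ssc._testInputStream([["a", "a", "b"]])  # one batch of words

counts = words.map(lambda w: (w, 1)).reduceByKey(operator.add)
bumped = counts.mapValues(lambda c: c + 10)             # ("a", 2) -> ("a", 12); keys and partitioning kept
expanded = counts.flatMapValues(lambda c: (c, c + 10))  # ("a", 2) -> ("a", 2), ("a", 12)
partitions = counts.glom()                              # one list of pairs per partition of each RDD
sizes = words.mapPartitions(lambda it: [sum(1 for _ in it)])  # per-partition element counts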

python/pyspark/streaming_tests.py

Lines changed: 46 additions & 2 deletions

@@ -142,10 +142,54 @@ def test_func(dstream):
         output = self._run_stream(test_input, test_func, expected_output)
         self.assertEqual(expected_output, output)

-    def _run_stream(self, test_input, test_func, expected_output):
+    def test_mapValues(self):
+        """Basic operation test for DStream.mapValues."""
+        test_input = [["a", "a", "b"], ["", ""], []]
+
+        def test_func(dstream):
+            return dstream.map(lambda x: (x, 1)).reduceByKey(operator.add).mapValues(lambda x: x + 10)
+        expected_output = [[("a", 12), ("b", 11)], [("", 12)], []]
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_flatMapValues(self):
+        """Basic operation test for DStream.flatMapValues."""
+        test_input = [["a", "a", "b"], ["", ""], []]
+
+        def test_func(dstream):
+            return dstream.map(lambda x: (x, 1)).reduceByKey(operator.add).flatMapValues(lambda x: (x, x + 10))
+        expected_output = [[("a", 2), ("a", 12), ("b", 1), ("b", 11)], [("", 2), ("", 12)], []]
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_glom(self):
+        """Basic operation test for DStream.glom."""
+        test_input = [range(1, 5), range(5, 9), range(9, 13)]
+        numSlices = 2
+
+        def test_func(dstream):
+            dstream.pyprint()
+            return dstream.glom()
+        expected_output = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
+        output = self._run_stream(test_input, test_func, expected_output, numSlices)
+        self.assertEqual(expected_output, output)
+
+    def test_mapPartitions(self):
+        """Basic operation test for DStream.mapPartitions."""
+        test_input = [range(1, 5), range(5, 9), range(9, 13)]
+        numSlices = 2
+
+        def test_func(dstream):
+            dstream.pyprint()
+            return dstream.mapPartitions(lambda x: [reduce(operator.add, x)])
+        expected_output = [[3, 7], [11, 15], [19, 23]]
+        output = self._run_stream(test_input, test_func, expected_output, numSlices)
+        self.assertEqual(expected_output, output)
+
+    def _run_stream(self, test_input, test_func, expected_output, numSlices=None):
         """Start stream and return the output"""
         # Generate input stream with user-defined input
-        test_input_stream = self.ssc._testInputStream(test_input)
+        test_input_stream = self.ssc._testInputStream(test_input, numSlices)
         # Apply test function to stream
         test_stream = test_func(test_input_stream)
         # Add job to get output from stream
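
The glom and mapPartitions expectations above rest on numSlices=2 splitting each four-element batch evenly, so range(1, 5) becomes partitions [1, 2] and [3, 4]. A quick local sanity check of that arithmetic (plain Python; the even split mimics Spark's partitioning of a sized input and is an assumption, not Spark code):

def split_evenly(xs, n):
    # Split xs into n equal-size contiguous chunks (assumes len(xs) % n == 0).
    k = len(xs) // n
    return [xs[i * k:(i + 1) * k] for i in range(n)]

batch = list(range(1, 5))
assert split_evenly(batch, 2) == [[1, 2], [3, 4]]           # what glom() should emit
assert [sum(p) for p in split_evenly(batch, 2)] == [3, 7]   # per-partition sums from mapPartitions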