Commit 94f2b65

remove waste duplicated code
1 parent 580fbc2 commit 94f2b65

2 files changed: 56 additions & 62 deletions

python/pyspark/streaming/context.py

Lines changed: 1 addition & 42 deletions
@@ -130,48 +130,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
         # Stop Callback server
         SparkContext._gateway.shutdown()

-    def checkpoint(self, directory):
-        """
-        Not tested
-        """
-        self._jssc.checkpoint(directory)
-
     def _testInputStream(self, test_inputs, numSlices=None):
-        """
-        Generate multiple files to make "stream" in Scala side for test.
-        Scala chooses one of the files and generates RDD using PythonRDD.readRDDFromFile.
-
-        QueStream maybe good way to implement this function
-        """
-        numSlices = numSlices or self._sc.defaultParallelism
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
-
-        tempFiles = list()
-        for test_input in test_inputs:
-            tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(test_input):
-                test_input = list(test_input)  # Make it a list so we can compute its length
-            batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
-            if batchSize > 1:
-                serializer = BatchedSerializer(self._sc._unbatched_serializer,
-                                               batchSize)
-            else:
-                serializer = self._sc._unbatched_serializer
-            serializer.dump_stream(test_input, tempFile)
-            tempFile.close()
-            tempFiles.append(tempFile.name)
-
-        jtempFiles = ListConverter().convert(tempFiles, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
-                                                        jtempFiles,
-                                                        numSlices).asJavaDStream()
-        return DStream(jinput_stream, self, BatchedSerializer(PickleSerializer()))
-
-    def _testInputStream2(self, test_inputs, numSlices=None):
         """
         This is inpired by QueStream implementation. Give list of RDD and generate DStream
         which contain the RDD.
@@ -184,7 +143,7 @@ def _testInputStream2(self, test_inputs, numSlices=None):
             test_rdd_deserializers.append(test_rdd._jrdd_deserializer)

         jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream2(self._jssc, jtest_rdds).asJavaDStream()
+        jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()

         dstream = DStream(jinput_stream, self, test_rdd_deserializers[0])
         return dstream
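
Both the removed file-based helper and the surviving _testInputStream use the same Py4J idiom: convert a Python list into a java.util.List with ListConverter before handing it to a JVM-side constructor such as PythonTestInputStream. A minimal standalone sketch of that idiom, not part of this commit, assuming a Py4J GatewayServer is already running on the JVM side:

from py4j.java_gateway import JavaGateway
from py4j.java_collections import ListConverter

# Assumption: some JVM process has started a GatewayServer we can attach to.
gateway = JavaGateway()

py_list = ["file-1", "file-2", "file-3"]
# Produces a java.util.List proxy that JVM methods expecting a Java list
# (as in the calls above) can accept directly.
java_list = ListConverter().convert(py_list, gateway._gateway_client)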

python/pyspark/streaming/dstream.py

Lines changed: 55 additions & 20 deletions
@@ -17,12 +17,13 @@

 from collections import defaultdict
 from itertools import chain, ifilter, imap
-import time
 import operator

 from pyspark.serializers import NoOpSerializer,\
     BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.rdd import _JavaStackTrace
+from pyspark.storagelevel import StorageLevel
+from pyspark.resultiterable import ResultIterable

 from py4j.java_collections import ListConverter, MapConverter

@@ -35,6 +36,8 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._ssc = ssc
         self.ctx = ssc._sc
         self._jrdd_deserializer = jrdd_deserializer
+        self.is_cached = False
+        self.is_checkpointed = False

     def context(self):
         """
@@ -234,8 +237,6 @@ def takeAndPrint(rdd, time):
             taken = rdd.take(11)
             print "-------------------------------------------"
             print "Time: %s" % (str(time))
-            print rdd.glom().collect()
-            print "-------------------------------------------"
             print "-------------------------------------------"
             for record in taken[:10]:
                 print record
@@ -290,32 +291,65 @@ def get_output(rdd, time):

         self.foreachRDD(get_output)

-    def _test_switch_dserializer(self, serializer_que):
+    def cache(self):
+        """
+        Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}).
+        """
+        self.is_cached = True
+        self.persist(StorageLevel.MEMORY_ONLY_SER)
+        return self
+
+    def persist(self, storageLevel):
+        """
+        Set this DStream's storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the DStream does not have a storage level set yet.
+        """
+        self.is_cached = True
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jdstream.persist(javaStorageLevel)
+        return self
+
+    def checkpoint(self, interval):
         """
-        Deserializer is dynamically changed based on numSlice and the number of
-        input. This function choose deserializer. Currently this is just FIFO.
+        Mark this DStream for checkpointing. It will be saved to a file inside the
+        checkpoint directory set with L{SparkContext.setCheckpointDir()}
+
+        I am not sure this part in DStream
+        and
+        all references to its parent RDDs will be removed. This function must
+        be called before any job has been executed on this RDD. It is strongly
+        recommended that this RDD is persisted in memory, otherwise saving it
+        on a file will require recomputation.
+
+        interval must be pysprak.streaming.duration
         """
-
-        jrdd_deserializer = self._jrdd_deserializer
+        self.is_checkpointed = True
+        self._jdstream.checkpoint(interval)
+        return self
+
+    def groupByKey(self, numPartitions=None):
+        def createCombiner(x):
+            return [x]

-        def switch(rdd, jtime):
-            try:
-                print serializer_que
-                jrdd_deserializer = serializer_que.pop(0)
-                print jrdd_deserializer
-            except Exception as e:
-                print e
+        def mergeValue(xs, x):
+            xs.append(x)
+            return xs

-        self.foreachRDD(switch)
+        def mergeCombiners(a, b):
+            a.extend(b)
+            return a

+        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+                                 numPartitions).mapValues(lambda x: ResultIterable(x))


 # TODO: implement groupByKey
+# TODO: implement saveAsTextFile
+
+# Following operation has dependency to transform
 # TODO: impelment union
-# TODO: implement cache
-# TODO: implement persist
 # TODO: implement repertitions
-# TODO: implement saveAsTextFile
 # TODO: implement cogroup
 # TODO: implement join
 # TODO: implement countByValue
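
The new groupByKey above is built from combineByKey with three small combiner functions. A plain-Python illustration of what those three roles do, outside Spark and not part of this commit:

def group_by_key_locally(pairs):
    """Local stand-in for one partition's work in the combineByKey-based groupByKey."""
    combined = {}
    for key, value in pairs:
        if key not in combined:
            combined[key] = [value]        # createCombiner: start a list for a new key
        else:
            combined[key].append(value)    # mergeValue: add a value to this partition's list
    return combined

def merge_partitions(left, right):
    """mergeCombiners: concatenate per-key lists coming from two partitions."""
    for key, values in right.items():
        left.setdefault(key, []).extend(values)
    return left

part1 = group_by_key_locally([("a", 1), ("b", 2)])
part2 = group_by_key_locally([("a", 3)])
print(merge_partitions(part1, part2))  # {'a': [1, 3], 'b': [2]}
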
@@ -342,6 +376,7 @@ def pipeline_func(split, iterator):
         self._prev_jdstream = prev._prev_jdstream  # maintain the pipeline
         self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
+        self.is_checkpointed = False
         self._ssc = prev._ssc
         self.ctx = prev.ctx
         self.prev = prev
@@ -378,4 +413,4 @@ def _jdstream(self):
         return self._jdstream_val

     def _is_pipelinable(self):
-        return not self.is_cached
+        return not (self.is_cached or self.is_checkpointed)
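
Taken together, the dstream.py changes add is_cached/is_checkpointed bookkeeping and make _is_pipelinable respect both flags. A self-contained toy, not PySpark itself, that mirrors just that bookkeeping:

class FlagBookkeepingDemo(object):
    """Toy class mirroring the is_cached/is_checkpointed flags added above."""
    def __init__(self):
        self.is_cached = False
        self.is_checkpointed = False

    def cache(self):
        self.is_cached = True        # the real DStream.cache() also calls persist()
        return self

    def checkpoint(self, interval):
        self.is_checkpointed = True  # the real code also calls self._jdstream.checkpoint(interval)
        return self

    def _is_pipelinable(self):
        # Same predicate as the new _is_pipelinable: caching or checkpointing
        # breaks the pipeline, so later operators start a fresh stage.
        return not (self.is_cached or self.is_checkpointed)

d = FlagBookkeepingDemo()
print(d._is_pipelinable())       # True
d.cache().checkpoint(interval=1)
print(d._is_pipelinable())       # False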
