
Commit f76c182

remove waste duplicated code
1 parent 18c8723 commit f76c182

2 files changed: +56 −62 lines changed

python/pyspark/streaming/context.py

Lines changed: 1 addition & 42 deletions
@@ -130,48 +130,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
         # Stop Callback server
         SparkContext._gateway.shutdown()
 
-    def checkpoint(self, directory):
-        """
-        Not tested
-        """
-        self._jssc.checkpoint(directory)
-
     def _testInputStream(self, test_inputs, numSlices=None):
-        """
-        Generate multiple files to make "stream" in Scala side for test.
-        Scala chooses one of the files and generates RDD using PythonRDD.readRDDFromFile.
-
-        QueStream maybe good way to implement this function
-        """
-        numSlices = numSlices or self._sc.defaultParallelism
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
-
-        tempFiles = list()
-        for test_input in test_inputs:
-            tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(test_input):
-                test_input = list(test_input)  # Make it a list so we can compute its length
-            batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
-            if batchSize > 1:
-                serializer = BatchedSerializer(self._sc._unbatched_serializer,
-                                               batchSize)
-            else:
-                serializer = self._sc._unbatched_serializer
-            serializer.dump_stream(test_input, tempFile)
-            tempFile.close()
-            tempFiles.append(tempFile.name)
-
-        jtempFiles = ListConverter().convert(tempFiles, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
-                                                        jtempFiles,
-                                                        numSlices).asJavaDStream()
-        return DStream(jinput_stream, self, BatchedSerializer(PickleSerializer()))
-
-    def _testInputStream2(self, test_inputs, numSlices=None):
         """
         This is inpired by QueStream implementation. Give list of RDD and generate DStream
         which contain the RDD.
@@ -184,7 +143,7 @@ def _testInputStream2(self, test_inputs, numSlices=None):
             test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
 
         jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream2(self._jssc, jtest_rdds).asJavaDStream()
+        jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()
 
         dstream = DStream(jinput_stream, self, test_rdd_deserializers[0])
         return dstream
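
For orientation, here is a minimal sketch of how a test might exercise the surviving _testInputStream helper after this change. It is not part of the commit: the ssc variable, the batch data, and the collect-into-a-list pattern are assumptions made for illustration only.

    # Hypothetical usage of the consolidated _testInputStream helper.
    # `ssc` is assumed to be an already constructed StreamingContext; its
    # setup and teardown are outside the scope of this diff.
    input_batches = [[1, 2, 3], [4, 5], [6]]           # one list per batch
    test_stream = ssc._testInputStream(input_batches)   # returns a DStream

    collected = []
    # foreachRDD callbacks in this prototype take (rdd, time).
    test_stream.foreachRDD(lambda rdd, time: collected.append(rdd.collect()))

    ssc.start()
    # ... let the batches run, then inspect `collected` and stop the context.
    ssc.stop()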

python/pyspark/streaming/dstream.py

Lines changed: 55 additions & 20 deletions
@@ -17,12 +17,13 @@
 
 from collections import defaultdict
 from itertools import chain, ifilter, imap
-import time
 import operator
 
 from pyspark.serializers import NoOpSerializer,\
     BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.rdd import _JavaStackTrace
+from pyspark.storagelevel import StorageLevel
+from pyspark.resultiterable import ResultIterable
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -35,6 +36,8 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._ssc = ssc
         self.ctx = ssc._sc
         self._jrdd_deserializer = jrdd_deserializer
+        self.is_cached = False
+        self.is_checkpointed = False
 
     def context(self):
         """
@@ -247,8 +250,6 @@ def takeAndPrint(rdd, time):
            taken = rdd.take(11)
            print "-------------------------------------------"
            print "Time: %s" % (str(time))
-            print rdd.glom().collect()
-            print "-------------------------------------------"
            print "-------------------------------------------"
            for record in taken[:10]:
                print record
@@ -303,32 +304,65 @@ def get_output(rdd, time):
 
         self.foreachRDD(get_output)
 
-    def _test_switch_dserializer(self, serializer_que):
+    def cache(self):
+        """
+        Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}).
+        """
+        self.is_cached = True
+        self.persist(StorageLevel.MEMORY_ONLY_SER)
+        return self
+
+    def persist(self, storageLevel):
+        """
+        Set this DStream's storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the DStream does not have a storage level set yet.
+        """
+        self.is_cached = True
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jdstream.persist(javaStorageLevel)
+        return self
+
+    def checkpoint(self, interval):
         """
-        Deserializer is dynamically changed based on numSlice and the number of
-        input. This function choose deserializer. Currently this is just FIFO.
+        Mark this DStream for checkpointing. It will be saved to a file inside the
+        checkpoint directory set with L{SparkContext.setCheckpointDir()}
+
+        I am not sure this part in DStream
+        and
+        all references to its parent RDDs will be removed. This function must
+        be called before any job has been executed on this RDD. It is strongly
+        recommended that this RDD is persisted in memory, otherwise saving it
+        on a file will require recomputation.
+
+        interval must be pysprak.streaming.duration
         """
-
-        jrdd_deserializer = self._jrdd_deserializer
+        self.is_checkpointed = True
+        self._jdstream.checkpoint(interval)
+        return self
+
+    def groupByKey(self, numPartitions=None):
+        def createCombiner(x):
+            return [x]
 
-        def switch(rdd, jtime):
-            try:
-                print serializer_que
-                jrdd_deserializer = serializer_que.pop(0)
-                print jrdd_deserializer
-            except Exception as e:
-                print e
+        def mergeValue(xs, x):
+            xs.append(x)
+            return xs
 
-        self.foreachRDD(switch)
+        def mergeCombiners(a, b):
+            a.extend(b)
+            return a
 
+        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+                                 numPartitions).mapValues(lambda x: ResultIterable(x))
 
 
 # TODO: implement groupByKey
+# TODO: implement saveAsTextFile
+
+# Following operation has dependency to transform
 # TODO: impelment union
-# TODO: implement cache
-# TODO: implement persist
 # TODO: implement repertitions
-# TODO: implement saveAsTextFile
 # TODO: implement cogroup
 # TODO: implement join
 # TODO: implement countByValue
@@ -355,6 +389,7 @@ def pipeline_func(split, iterator):
         self._prev_jdstream = prev._prev_jdstream  # maintain the pipeline
         self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
+        self.is_checkpointed = False
         self._ssc = prev._ssc
         self.ctx = prev.ctx
         self.prev = prev
@@ -391,4 +426,4 @@ def _jdstream(self):
         return self._jdstream_val
 
     def _is_pipelinable(self):
-        return not self.is_cached
+        return not (self.is_cached or self.is_checkpointed)
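
To make the new DStream surface concrete, a rough usage sketch follows. It is not taken from the commit: the word-count pipeline, the source stream, and the checkpoint interval are assumptions; only cache(), persist(), checkpoint(), and groupByKey() come from the diff above.

    # Hypothetical pipeline over a test input stream; `ssc` is assumed to be
    # an existing StreamingContext from this prototype.
    lines = ssc._testInputStream([["a b", "b c"], ["c d"]])

    pairs = lines.flatMap(lambda line: line.split(" ")) \
                 .map(lambda word: (word, 1))

    # cache() sets is_cached and persists with MEMORY_ONLY_SER, so the
    # pipelining check (_is_pipelinable) now treats this DStream as a
    # boundary instead of fusing later stages into it.
    pairs.cache()

    # checkpoint() sets is_checkpointed and forwards the interval to the
    # underlying JavaDStream; per the docstring the interval should come
    # from pyspark.streaming.duration (e.g. Seconds(10)).
    # pairs.checkpoint(Seconds(10))

    # groupByKey() is built on combineByKey: a list is created per key,
    # appended to per value, extended across partitions, and the result is
    # wrapped in ResultIterable by mapValues. Each batch of `grouped` pairs
    # a word with an iterable of 1s.
    grouped = pairs.groupByKey()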
