
Commit e108ec1

address comments
1 parent 37fe06f commit e108ec1

6 files changed (+80, -86 lines)


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 0 additions & 8 deletions
@@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.language.existentials
-import scala.reflect.ClassTag
-import scala.util.{Try, Success, Failure}
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
@@ -52,12 +50,6 @@ private[spark] class PythonRDD(
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
-  // create a new PythonRDD with same Python setting but different parent.
-  def copyTo(rdd: RDD[_]): PythonRDD = {
-    new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning,
-      pythonExec, broadcastVars, accumulator)
-  }
-
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
   val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

python/pyspark/rdd.py

Lines changed: 2 additions & 0 deletions
@@ -787,6 +787,8 @@ def sum(self):
         >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
         6.0
         """
+        if not self.getNumPartitions():
+            return 0  # empty RDD can not been reduced
         return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
 
     def count(self):
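
A minimal, hedged sketch of the behaviour this guard targets (the app name and sample data are placeholders, not part of the commit): reduce() raises "Can not reduce() empty RDD" when there is nothing to combine, so sum() now returns 0 up front for an RDD that reports zero partitions, e.g. the RDD behind an empty streaming batch.

from pyspark import SparkContext

sc = SparkContext("local[2]", "sum-guard-sketch")  # placeholder app name

rdd = sc.parallelize([1.0, 2.0, 3.0])
print(rdd.sum())  # 6.0, unchanged behaviour

# The new branch only fires when getNumPartitions() == 0; an RDD that
# merely has no elements still sums to 0.0 through mapPartitions/reduce
# as before.
empty = rdd.filter(lambda x: x > 100.0)
print(empty.sum())  # 0.0

sc.stop()
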

python/pyspark/streaming/context.py

Lines changed: 22 additions & 16 deletions
@@ -84,17 +84,18 @@ class StreamingContext(object):
     """
     _transformerSerializer = None
 
-    def __init__(self, sparkContext, duration=None, jssc=None):
+    def __init__(self, sparkContext, batchDuration=None, jssc=None):
         """
         Create a new StreamingContext.
 
         @param sparkContext: L{SparkContext} object.
-        @param duration: number of seconds.
+        @param batchDuration: the time interval (in seconds) at which streaming
+            data will be divided into batches
         """
 
         self._sc = sparkContext
         self._jvm = self._sc._jvm
-        self._jssc = jssc or self._initialize_context(self._sc, duration)
+        self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
 
     def _initialize_context(self, sc, duration):
         self._ensure_initialized()
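
At a call site the renamed keyword reads as below; a minimal sketch assuming a local SparkContext, with the one-second interval chosen purely as an example.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "batch-duration-sketch")  # placeholder app name

# batchDuration: interval in seconds at which incoming data is divided
# into batches (one RDD per second here).
ssc = StreamingContext(sc, batchDuration=1)
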
@@ -134,26 +135,27 @@ def _ensure_initialized(cls):
             SparkContext._active_spark_context, CloudPickleSerializer(), gw)
 
     @classmethod
-    def getOrCreate(cls, path, setupFunc):
+    def getOrCreate(cls, checkpointPath, setupFunc):
         """
-        Get the StreamingContext from checkpoint file at `path`, or setup
-        it by `setupFunc`.
+        Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+        If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+        recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
+        will be used to create a JavaStreamingContext.
 
-        :param path: directory of checkpoint
-        :param setupFunc: a function used to create StreamingContext and
-                          setup DStreams.
-        :return: a StreamingContext
+        @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+        @param setupFunc Function to create a new JavaStreamingContext and setup DStreams
         """
-        if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path):
+        # TODO: support checkpoint in HDFS
+        if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
             ssc = setupFunc()
-            ssc.checkpoint(path)
+            ssc.checkpoint(checkpointPath)
             return ssc
 
         cls._ensure_initialized()
         gw = SparkContext._gateway
 
         try:
-            jssc = gw.jvm.JavaStreamingContext(path)
+            jssc = gw.jvm.JavaStreamingContext(checkpointPath)
         except Exception:
             print >>sys.stderr, "failed to load StreamingContext from checkpoint"
             raise
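
A hedged usage sketch of the renamed getOrCreate (the checkpoint directory and setup function below are placeholders): if the directory already holds checkpoint data the context is recreated from it, otherwise setupFunc builds a fresh one.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

checkpoint_dir = "/tmp/streaming-checkpoint"  # placeholder local path

def create_context():
    # Build a fresh StreamingContext and register the checkpoint directory.
    sc = SparkContext("local[2]", "get-or-create-sketch")
    ssc = StreamingContext(sc, batchDuration=1)
    ssc.checkpoint(checkpoint_dir)
    return ssc

# Recreated from checkpoint data when checkpoint_dir is non-empty,
# otherwise built by create_context().
ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
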
@@ -249,12 +251,12 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def _check_serialzers(self, rdds):
+    def _check_serializers(self, rdds):
         # make sure they have same serializer
         if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
             for i in range(len(rdds)):
                 # reset them to sc.serializer
-                rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
+                rdds[i] = rdds[i]._reserialize()
 
     def queueStream(self, rdds, oneAtATime=True, default=None):
         """
@@ -275,7 +277,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None):
 
         if rdds and not isinstance(rdds[0], RDD):
             rdds = [self._sc.parallelize(input) for input in rdds]
-        self._check_serialzers(rdds)
+        self._check_serializers(rdds)
 
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
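
For context, a brief hedged sketch of queueStream as exercised by this change (names and data are placeholders): plain Python lists are parallelized into RDDs and then passed through the serializer check above.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "queue-stream-sketch")  # placeholder app name
ssc = StreamingContext(sc, batchDuration=1)

# Each list becomes one RDD; with oneAtATime=True (the default) one RDD
# is consumed per batch interval.
stream = ssc.queueStream([[1, 2, 3], [4, 5]])
stream.pprint()

ssc.start()
ssc.awaitTermination(5)  # run a few batches, then shut down
ssc.stop()
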
@@ -313,6 +315,10 @@ def union(self, *dstreams):
             raise ValueError("should have at least one DStream to union")
         if len(dstreams) == 1:
             return dstreams[0]
+        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+            raise ValueError("All DStreams should have same serializer")
+        if len(set(s._slideDuration for s in dstreams)) > 1:
+            raise ValueError("All DStreams should have same slide duration")
         first = dstreams[0]
         jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
                                         SparkContext._gateway._gateway_client)
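
A hedged sketch of the new validation (stream contents are placeholders): DStreams with matching serializers and slide durations union as before, while a mismatch now fails fast in Python rather than surfacing as a Java-side error.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "union-check-sketch")  # placeholder app name
ssc = StreamingContext(sc, batchDuration=1)

s1 = ssc.queueStream([[1, 2], [3]])
s2 = ssc.queueStream([[4], [5, 6]])

# Same serializer, same slide duration: passes both new checks.
merged = ssc.union(s1, s2)
merged.pprint()

# A windowed stream can slide at a different rate; with this patch,
# unioning it with s1 would raise
# ValueError("All DStreams should have same slide duration").
# windowed = s2.window(windowDuration=4, slideDuration=2)
# ssc.union(s1, windowed)
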
