Commit d328aca

fix serializer in queueStream

1 parent 6f0da2f

3 files changed (+20, -12 lines)

python/pyspark/streaming/context.py

Lines changed: 16 additions & 8 deletions
@@ -238,29 +238,37 @@ def textFileStream(self, directory):

     def _check_serialzers(self, rdds):
         # make sure they have same serializer
-        if len(set(rdd._jrdd_deserializer for rdd in rdds)):
+        if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
             for i in range(len(rdds)):
                 # reset them to sc.serializer
                 rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)

-    def queueStream(self, queue, oneAtATime=True, default=None):
+    def queueStream(self, rdds, oneAtATime=True, default=None):
         """
         Create an input stream from an queue of RDDs or list. In each batch,
         it will process either one or all of the RDDs returned by the queue.

         NOTE: changes to the queue after the stream is created will not be recognized.
-        @param queue Queue of RDDs
-        @tparam T Type of objects in the RDD
+
+        @param rdds       Queue of RDDs
+        @param oneAtATime pick one rdd each time or pick all of them once.
+        @param default    The default rdd if no more in rdds
         """
-        if queue and not isinstance(queue[0], RDD):
-            rdds = [self._sc.parallelize(input) for input in queue]
-        else:
-            rdds = queue
+        if default and not isinstance(default, RDD):
+            default = self._sc.parallelize(default)
+
+        if not rdds and default:
+            rdds = [rdds]
+
+        if rdds and not isinstance(rdds[0], RDD):
+            rdds = [self._sc.parallelize(input) for input in rdds]
         self._check_serialzers(rdds)
+
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
         queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
         if default:
+            default = default._reserialize(rdds[0]._jrdd_deserializer)
             jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
         else:
             jdstream = self._jssc.queueStream(queue, oneAtATime)
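
For context: the old check `len(set(...))` was truthy whenever the queue held any RDDs at all, so the re-serialization map ran even when every RDD already shared a serializer; `> 1` limits it to genuine mismatches. Below is a minimal, hedged sketch of driving the patched queueStream from a standalone script (the scaffolding and names are illustrative, not part of the commit):

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "queueStreamDemo")
    ssc = StreamingContext(sc, 1)  # 1-second batch interval

    # Plain lists are accepted too: queueStream parallelizes them into RDDs,
    # and _check_serialzers now only re-maps on a real serializer mismatch.
    rdds = [sc.parallelize(range(i * 10, (i + 1) * 10)) for i in range(3)]

    # oneAtATime=True feeds one RDD per batch; once the queue drains, the
    # default RDD (reserialized to match rdds[0]) is emitted instead.
    stream = ssc.queueStream(rdds, oneAtATime=True, default=sc.parallelize([-1]))
    stream.pprint()

    ssc.start()
    ssc.awaitTermination(5)
    ssc.stop()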

python/pyspark/streaming/dstream.py

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ def transformWith(self, func, other, keepSerializer=False):
             oldfunc = func
             func = lambda t, a, b: oldfunc(a, b)
         assert func.func_code.co_argcount == 3, "func should take two or three arguments"
-        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
                                                           other._jdstream.dstream(), jfunc)
         jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
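
The one-line change passes the second stream's deserializer to RDDFunction, so a binary transform decodes each input with its own serializer instead of assuming both sides match. A hedged usage sketch (assumes an existing SparkContext `sc` and StreamingContext `ssc`; the join itself is illustrative):

    # Two queue streams whose underlying RDDs may carry different
    # deserializers, e.g. if one side was re-mapped by _check_serialzers
    # and the other was not.
    left = ssc.queueStream([sc.parallelize([(1, "a"), (2, "b")])])
    right = ssc.queueStream([sc.parallelize([(1, "x"), (2, "y")])])

    # transformWith hands one RDD from each stream to func per batch; with
    # this fix, RDDFunction is given both _jrdd_deserializer values, so
    # each side is unpickled correctly before func runs.
    joined = left.transformWith(lambda a, b: a.join(b), right)
    joined.pprint()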

python/pyspark/streaming/tests.py

Lines changed: 3 additions & 3 deletions
@@ -508,16 +508,16 @@ def setup():
             conf = SparkConf().set("spark.default.parallelism", 1)
             sc = SparkContext(conf=conf)
             ssc = StreamingContext(sc, .2)
-            rdd = sc.parallelize(range(10), 1)
+            rdd = sc.parallelize(range(1), 1)
             dstream = ssc.queueStream([rdd], default=rdd)
-            result[0] = self._collect(dstream.countByWindow(1, .2))
+            result[0] = self._collect(dstream.countByWindow(1, 0.2))
             return ssc
         tmpd = tempfile.mkdtemp("test_streaming_cps")
         ssc = StreamingContext.getOrCreate(tmpd, setup)
         ssc.start()
         ssc.awaitTermination(4)
         ssc.stop()
-        expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5
+        expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5
         self.assertEqual(expected, result[0][:10])

         ssc = StreamingContext.getOrCreate(tmpd, setup)
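
A quick sanity check of the updated expectation, under the test's own parameters (0.2 s batches, one element per batch via range(1), a 1 s count window):

    batch_interval = 0.2                # ssc batch duration (seconds)
    window = 1.0                        # countByWindow window length
    elems_per_batch = 1                 # len(range(1))
    batches_in_window = int(window / batch_interval)  # 5

    # The windowed count ramps while the window fills, then plateaus at 5.
    expected = [[min(i + 1, batches_in_window) * elems_per_batch]
                for i in range(10)]
    assert expected == [[1], [2], [3], [4], [5]] + [[5]] * 5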
