
Commit dce8908

xuanyuanking authored and Jackey Lee committed
[SPARK-26549][PYSPARK] Fix python worker reuse taking no effect for parallelize over a lazy iterable range
## What changes were proposed in this pull request?

During follow-up work (apache#23435) on the PySpark worker-reuse scenario, we found that worker reuse takes no effect for `sc.parallelize(xrange(...))`. This happens because the specialized `rdd.parallelize` logic for `xrange` (introduced in apache#3264) generates the data from a lazy iterable range and never uses the passed-in iterator. That breaks the end-of-stream check in the Python worker and ultimately causes worker reuse to take no effect. See the [SPARK-26549](https://issues.apache.org/jira/browse/SPARK-26549) description for more details. We fix this by forcing consumption of the passed-in iterator.

## How was this patch tested?

New UT in test_worker.py.

Closes apache#23470 from xuanyuanking/SPARK-26549.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 1a794bf · commit dce8908
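For context, the symptom can be observed by comparing worker PIDs across two jobs on the same RDD, which is the same pattern the new unit test below uses. A minimal sketch, assuming an active SparkContext `sc` (e.g. the PySpark shell), the default `spark.python.worker.reuse=true`, and `xrange` available (Python 2, or aliased to `range` on Python 3):

```python
import os

# Run two jobs over the same parallelized xrange and compare the worker
# PIDs they report. With worker reuse working, the second job runs in the
# same worker processes as the first; before this fix, parallelizing an
# xrange broke the end-of-stream check, so fresh workers were spawned.
rdd = sc.parallelize(xrange(20), 8)
previous_pids = rdd.map(lambda x: os.getpid()).collect()
current_pids = rdd.map(lambda x: os.getpid()).collect()
print(all(pid in previous_pids for pid in current_pids))  # True when workers are reused
```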

File tree

2 files changed: +19 -1


python/pyspark/context.py

Lines changed: 8 additions & 0 deletions
@@ -498,6 +498,14 @@ def getStart(split):
             return start0 + int((split * size / numSlices)) * step

         def f(split, iterator):
+            # it's an empty iterator here but we need this line for triggering the
+            # logic of signal handling in FramedSerializer.load_stream, for instance,
+            # SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
+            # FramedSerializer.load_stream produces a generator, the control should
+            # at least be in that function once. Here we do it by explicitly converting
+            # the empty iterator to a list, thus making sure worker reuse takes effect.
+            # See more details in SPARK-26549.
+            assert len(list(iterator)) == 0
             return xrange(getStart(split), getStart(split + 1), step)

         return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
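The new comment hinges on a general Python property: a generator function's body does not execute until the generator is iterated, so `FramedSerializer.load_stream`'s end-of-stream handling never runs if nobody touches the iterator. A minimal, self-contained sketch of that property (illustrative only, not PySpark's actual serializer code):

```python
def load_stream():
    # Stand-in for FramedSerializer.load_stream: in PySpark, the
    # end-of-stream handling (e.g. END_OF_DATA_SECTION) lives in here.
    print("body entered; end-of-stream handling would run now")
    if False:
        yield  # the (unreachable) yield makes this a generator function

gen = load_stream()          # nothing printed yet: generator bodies are lazy
assert len(list(gen)) == 0   # forcing iteration enters the body once,
                             # mirroring the assert added in this patch
```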

python/pyspark/tests/test_worker.py

Lines changed: 11 additions & 1 deletion
@@ -22,7 +22,7 @@

 from py4j.protocol import Py4JJavaError

-from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest
+from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest

 if sys.version_info[0] >= 3:
     xrange = range

@@ -145,6 +145,16 @@ def test_with_different_versions_of_python(self):
             self.sc.pythonVer = version


+class WorkerReuseTest(PySparkTestCase):
+
+    def test_reuse_worker_of_parallelize_xrange(self):
+        rdd = self.sc.parallelize(xrange(20), 8)
+        previous_pids = rdd.map(lambda x: os.getpid()).collect()
+        current_pids = rdd.map(lambda x: os.getpid()).collect()
+        for pid in current_pids:
+            self.assertTrue(pid in previous_pids)
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_worker import *
