[SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope and properly raise errors from worker #24070
Changes from all commits
899ad8d
8c309c5
866d585
9ad3a77
3415ff1
57d251c
600a906
0a796d7
7847a14
a1f811a
29b8ab6
4f842dc
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,8 +163,63 @@ private[spark] object PythonRDD extends Logging { | |
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") | ||
} | ||
|
||
/** | ||
* A helper function to create a local RDD iterator and serve it via socket. Partitions are | ||
* collected as separate jobs, by order of index. Partition data is first requested by a | |
* non-zero integer to start a collection job. The response is prefaced by an integer with 1 | ||
* meaning partition data will be served, 0 meaning the local iterator has been consumed, | ||
* and -1 meaning an error occurred during collection. This function is used by | |
* pyspark.rdd._local_iterator_from_socket(). | ||
* | ||
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the | ||
* data collected from these jobs, and the secret for authentication. | ||
*/ | ||
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { | ||
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") | ||
val (port, secret) = SocketAuthServer.setupOneConnectionServer( | ||
authHelper, "serve toLocalIterator") { s => | ||
val out = new DataOutputStream(s.getOutputStream) | ||
val in = new DataInputStream(s.getInputStream) | ||
Utils.tryWithSafeFinally { | ||
|
||
// Collects a partition on each iteration | ||
val collectPartitionIter = rdd.partitions.indices.iterator.map { i => | ||
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head | ||
Note: this is the same function as
For performance, as mentioned in my questions, would it make sense to use something like an iterator with a look-ahead of, say, 1 partition (or X% of partitions) so we decrease the blocking time?
Yeah, that would have better performance, but it does say in the doc that max memory usage will be the largest partition. Going over that might cause problems for some people, no?
That's a good point @BryanCutler, we could implement the lookahead as a separate PR/JIRA and allow it to be turned on/off. I'd suggest this PR is more about fixing the behaviour of toLocalIterator memory-wise than the out-of-scope issue (although the out-of-scope issue is maybe more visible in the logs).
||
} | ||
|
||
// Read request for data and send next partition if nonzero | ||
var complete = false | ||
while (!complete && in.readInt() != 0) { | ||
if (collectPartitionIter.hasNext) { | ||
try { | ||
// Attempt to collect the next partition | ||
val partitionArray = collectPartitionIter.next() | ||
|
||
// Send response there is a partition to read | ||
out.writeInt(1) | ||
|
||
// Write the next object and signal end of data for this iteration | ||
writeIteratorToStream(partitionArray.toIterator, out) | ||
out.writeInt(SpecialLengths.END_OF_DATA_SECTION) | ||
out.flush() | ||
} catch { | ||
case e: SparkException => | ||
We want to catch any errors during the collection job, so I believe the
I think that's reasonable +1
||
// Send response that an error occurred followed by error message | ||
out.writeInt(-1) | ||
writeUTF(e.getMessage, out) | ||
complete = true | ||
} | ||
} else { | ||
// Send response there are no more partitions to read and close | ||
out.writeInt(0) | ||
complete = true | ||
} | ||
} | ||
} { | ||
out.close() | ||
in.close() | ||
} | ||
} | ||
Array(port, secret) | ||
} | ||
|
||
def readRDDFromFile( | ||
|
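The request/response handshake described in the doc comment of `toLocalIteratorAndServe` can be sketched in plain Python over a local socket pair. This is only an illustration of the status codes (1 = partition follows, 0 = consumed, -1 = error), not Spark's implementation: `serve_partitions`, `fetch_all`, `run`, the length-prefixed framing, and the `END_OF_DATA` value are all assumptions made for the example.

```python
import socket
import struct
import threading

END_OF_DATA = -1  # assumed stand-in for the end-of-data marker


def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def read_int(stream):
    data = stream.read(4)
    if len(data) < 4:
        raise EOFError("stream closed")
    return struct.unpack("!i", data)[0]


def serve_partitions(conn, jobs):
    """Answer each non-zero request with 1 + data, 0 when done, or -1 + message."""
    pending = iter(jobs)
    with conn, conn.makefile("rwb") as stream:
        while read_int(stream) != 0:
            job = next(pending, None)
            if job is None:
                write_int(0, stream)  # nothing left to serve
                break
            try:
                items = job()  # "collect" one partition, only when requested
            except Exception as exc:
                message = str(exc).encode("utf-8")
                write_int(-1, stream)  # error status, then the message
                write_int(len(message), stream)
                stream.write(message)
                break
            write_int(1, stream)  # a partition follows
            for item in items:
                write_int(len(item), stream)
                stream.write(item)
            write_int(END_OF_DATA, stream)
            stream.flush()


def fetch_all(conn):
    """Request partitions one at a time until the server reports 0 or -1."""
    items = []
    with conn, conn.makefile("rwb") as stream:
        while True:
            write_int(1, stream)  # any non-zero value requests the next partition
            stream.flush()
            status = read_int(stream)
            if status == 0:
                return items
            if status == -1:
                length = read_int(stream)
                raise RuntimeError(stream.read(length).decode("utf-8"))
            while True:
                length = read_int(stream)
                if length == END_OF_DATA:
                    break
                items.append(stream.read(length))


def run(jobs):
    server, client = socket.socketpair()
    worker = threading.Thread(target=serve_partitions, args=(server, jobs))
    worker.start()
    try:
        return fetch_all(client)
    finally:
        worker.join()
```

Each entry in `jobs` is a zero-argument callable standing in for one partition's collection job, so a failing callable exercises the -1 path and an empty list exercises an empty partition.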
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,9 +39,9 @@ | |
from itertools import imap as map, ifilter as filter | ||
|
||
from pyspark.java_gateway import local_connect_and_auth | ||
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ | ||
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ | ||
PickleSerializer, pack_long, AutoBatchedSerializer | ||
from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \ | ||
CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \ | ||
UTF8Deserializer, pack_long, read_int, write_int | ||
from pyspark.join import python_join, python_left_outer_join, \ | ||
python_right_outer_join, python_full_outer_join, python_cogroup | ||
from pyspark.statcounter import StatCounter | ||
|
@@ -138,15 +138,69 @@ def _parse_memory(s): | |
return int(float(s[:-1]) * units[s[-1].lower()]) | ||
|
||
|
||
def _load_from_socket(sock_info, serializer): | ||
def _create_local_socket(sock_info): | ||
(sockfile, sock) = local_connect_and_auth(*sock_info) | ||
# The RDD materialization time is unpredicable, if we set a timeout for socket reading | ||
# The RDD materialization time is unpredictable, if we set a timeout for socket reading | ||
# operation, it will very possibly fail. See SPARK-18281. | ||
sock.settimeout(None) | ||
return sockfile | ||
|
||
|
||
def _load_from_socket(sock_info, serializer): | ||
sockfile = _create_local_socket(sock_info) | ||
# The socket will be automatically closed when garbage-collected. | ||
I guess that this comment should be moved above in
||
return serializer.load_stream(sockfile) | ||
|
||
|
||
def _local_iterator_from_socket(sock_info, serializer): | ||
|
||
class PyLocalIterable(object): | ||
""" Create a synchronous local iterable over a socket """ | ||
|
||
def __init__(self, _sock_info, _serializer): | ||
self._sockfile = _create_local_socket(_sock_info) | ||
self._serializer = _serializer | ||
self._read_iter = iter([]) # Initialize as empty iterator | ||
self._read_status = 1 | ||
|
||
def __iter__(self): | ||
while self._read_status == 1: | ||
# Request next partition data from Java | ||
write_int(1, self._sockfile) | ||
self._sockfile.flush() | ||
|
||
# If response is 1 then there is a partition to read, if 0 then fully consumed | ||
self._read_status = read_int(self._sockfile) | ||
if self._read_status == 1: | ||
|
||
# Load the partition data as a stream and read each item | ||
self._read_iter = self._serializer.load_stream(self._sockfile) | ||
for item in self._read_iter: | ||
yield item | ||
|
||
# An error occurred, read error message and raise it | ||
elif self._read_status == -1: | ||
error_msg = UTF8Deserializer().loads(self._sockfile) | ||
raise RuntimeError("An error occurred while reading the next element from " | ||
"toLocalIterator: {}".format(error_msg)) | ||
|
||
def __del__(self): | ||
# If local iterator is not fully consumed, | ||
if self._read_status == 1: | ||
try: | ||
# Finish consuming partition data stream | ||
for _ in self._read_iter: | ||
pass | ||
# Tell Java to stop sending data and close connection | ||
write_int(0, self._sockfile) | ||
self._sockfile.flush() | ||
except Exception: | ||
# Ignore any errors, socket is automatically closed when garbage-collected | ||
pass | ||
|
||
return iter(PyLocalIterable(sock_info, serializer)) | ||
|
||
|
||
def ignore_unicode_prefix(f): | ||
""" | ||
Ignore the 'u' prefix of string in doc tests, to make it works | ||
|
@@ -2386,7 +2440,7 @@ def toLocalIterator(self): | |
""" | ||
with SCCallSiteSync(self.context) as css: | ||
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) | ||
return _load_from_socket(sock_info, self._jrdd_deserializer) | ||
return _local_iterator_from_socket(sock_info, self._jrdd_deserializer) | ||
|
||
def barrier(self): | ||
""" | ||
|
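The cleanup path in `PyLocalIterable.__del__` relies on the generator returned by `iter(...)` keeping the iterable alive until the caller drops it. A socket-free sketch of that lifecycle, assuming CPython-style reference counting, might look like the following; `ChunkedIterable` and its `log` argument are hypothetical names used only for this illustration.

```python
import gc


class ChunkedIterable(object):
    """Yields items chunk by chunk and records what was left unread on cleanup."""

    def __init__(self, chunks, log):
        self._chunks = iter(chunks)
        self._read_iter = iter([])  # current chunk, initially empty
        self._status = 1            # 1 while more chunks may follow
        self._log = log

    def __iter__(self):
        while self._status == 1:
            chunk = next(self._chunks, None)
            if chunk is None:
                self._status = 0    # fully consumed
            else:
                self._read_iter = iter(chunk)
                for item in self._read_iter:
                    yield item

    def __del__(self):
        # Only act if the caller stopped before the data was exhausted
        if self._status == 1:
            leftover = sum(1 for _ in self._read_iter)  # drain the current chunk
            self._log.append(("stopped early", leftover))


log = []
it = iter(ChunkedIterable([[1, 2, 3], [4, 5]], log))
first = next(it)
it = None        # drop the generator; the iterable it references is released too
gc.collect()     # not needed on CPython, included for other interpreters
```

In the PR the equivalent of the `log.append` step is writing a 0 to the socket so the JVM side stops without starting another job.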
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -677,6 +677,34 @@ def test_repr_behaviors(self): | |
self.assertEquals(None, df._repr_html_()) | ||
self.assertEquals(expected, df.__repr__()) | ||
|
||
def test_to_local_iterator(self): | ||
df = self.spark.range(8, numPartitions=4) | ||
expected = df.collect() | ||
it = df.toLocalIterator() | ||
self.assertEqual(expected, list(it)) | ||
|
||
# Test DataFrame with empty partition | ||
df = self.spark.range(3, numPartitions=4) | ||
it = df.toLocalIterator() | ||
expected = df.collect() | ||
self.assertEqual(expected, list(it)) | ||
|
||
def test_to_local_iterator_not_fully_consumed(self): | ||
# SPARK-23961: toLocalIterator throws exception when not fully consumed | ||
# Create a DataFrame large enough so that write to socket will eventually block | ||
df = self.spark.range(1 << 20, numPartitions=2) | ||
it = df.toLocalIterator() | ||
self.assertEqual(df.take(1)[0], next(it)) | ||
with QuietTest(self.sc): | ||
it = None # remove iterator from scope, socket is closed when cleaned up | ||
# Make sure normal df operations still work | ||
result = [] | ||
for i, row in enumerate(df.toLocalIterator()): | ||
result.append(row) | ||
if i == 7: | ||
break | ||
self.assertEqual(df.take(8), result) | ||
Note: this would not have crashed before, only generated the error and I don't think it's possible to check if this error happened
Is the error a JVM error? If so, we could grab the stderr/stdout and look for the error message in the result there?
Yeah, that might work. I can give it a try.
Did anything come of this? It's optional so don't block on it.
Oh right, I forgot about this. Let me give it a shot now.
It ends up being a little complicated, maybe better to try as a followup
||
|
||
|
||
class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): | ||
# These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,15 +60,12 @@ def test_sum(self): | |
self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) | ||
|
||
def test_to_localiterator(self): | ||
from time import sleep | ||
rdd = self.sc.parallelize([1, 2, 3]) | ||
it = rdd.toLocalIterator() | ||
sleep(5) | ||
This sleep is unnecessary.
I like less sleeps in the code <3
||
self.assertEqual([1, 2, 3], sorted(it)) | ||
|
||
rdd2 = rdd.repartition(1000) | ||
it2 = rdd2.toLocalIterator() | ||
sleep(5) | ||
ditto
||
self.assertEqual([1, 2, 3], sorted(it2)) | ||
|
||
def test_save_as_textfile_with_unicode(self): | ||
|
@@ -736,6 +733,34 @@ def test_overwritten_global_func(self): | |
global_func = lambda: "Yeah" | ||
self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah") | ||
|
||
def test_to_local_iterator_failure(self): | ||
# SPARK-27548 toLocalIterator task failure not propagated to Python driver | ||
|
||
def fail(_): | ||
raise RuntimeError("local iterator error") | ||
|
||
rdd = self.sc.range(10).map(fail) | ||
|
||
with self.assertRaisesRegexp(Exception, "local iterator error"): | ||
for _ in rdd.toLocalIterator(): | ||
pass | ||
|
||
def test_to_local_iterator_collects_single_partition(self): | ||
# Test that partitions are not computed until requested by iteration | ||
|
||
def fail_last(x): | ||
if x == 9: | ||
raise RuntimeError("This should not be hit") | ||
return x | ||
|
||
rdd = self.sc.range(12, numSlices=4).map(fail_last) | ||
it = rdd.toLocalIterator() | ||
|
||
# Only consume first 4 elements from partitions 1 and 2, this should not collect the last | ||
# partition which would trigger the error | ||
for i in range(4): | ||
self.assertEqual(i, next(it)) | ||
|
||
|
||
if __name__ == "__main__": | ||
import unittest | ||
|
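The property checked by `test_to_local_iterator_collects_single_partition` (a partition is computed only once iteration reaches it) can be mimicked without Spark. The sketch below assumes 12 elements split evenly across 4 partitions; `lazy_partitions`, `compute`, and `jobs_run` are hypothetical names for this example only.

```python
def lazy_partitions(num_partitions, compute, jobs_run):
    """Yield elements partition by partition, computing each one on demand."""
    for index in range(num_partitions):
        jobs_run.append(index)  # record that a "job" ran for this partition
        for item in compute(index):
            yield item


def compute(index):
    # Partition i holds elements 3*i .. 3*i + 2; the last one contains 9
    values = list(range(index * 3, index * 3 + 3))
    if 9 in values:
        raise RuntimeError("This should not be hit")
    return values


jobs_run = []
it = lazy_partitions(4, compute, jobs_run)
first_four = [next(it) for _ in range(4)]
```

After these four `next` calls only the first two partitions have been computed, so the failing last partition is never touched.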
Once the local iterator goes out of scope on the Python side, will the remaining jobs still be triggered after the Scala side can no longer write to the closed connection?
No, the remaining jobs are not triggered. The Python iterator finishes consuming the data from the current job, then sends a command for the Scala iterator to stop.
How about the previous behavior? Would the previous behavior trigger them? It looks like toLocalIterator won't trigger the job if we don't iterate over the data in a partition.
The previous behavior was that the Scala local iterator would advance as long as the write calls to the socket are not blocked. So this means when Python reads a batch (auto-batched elements) from the current partition, this will unblock the Scala call to write and could start a job to collect the next partition.
Once the local iterator on the Python side goes out of scope without being fully consumed, will that block the write call in Scala? It seems to me that it will, and we shouldn't see unneeded jobs triggered after that, right?
The previous behavior is that when the iterator goes out of scope, the socket is eventually closed. This creates the error on the Scala side and the writing thread is terminated, so no more jobs are triggered, but the user sees this error.
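The pre-change behavior discussed in this thread, where the writer keeps advancing until its socket write blocks, comes down to kernel socket buffering: a writer can run ahead of the reader by however much the buffer holds. A rough Python sketch of that effect, using a non-blocking local socket pair so the "would block" point is observable, is below; `send_until_blocked` and `drain` are hypothetical helpers, and buffer sizes vary by platform.

```python
import socket


def send_until_blocked(sock, chunk=b"x" * 4096):
    """Write until the kernel buffer is full; return the number of bytes accepted."""
    sent = 0
    try:
        while True:
            sent += sock.send(chunk)
    except BlockingIOError:
        return sent


def drain(sock):
    """Read everything currently buffered; return the number of bytes read."""
    received = 0
    try:
        while True:
            data = sock.recv(65536)
            if not data:
                return received
            received += len(data)
    except BlockingIOError:
        return received


writer, reader = socket.socketpair()
writer.setblocking(False)
reader.setblocking(False)

ran_ahead = send_until_blocked(writer)  # bytes written before any read happened
drained = drain(reader)                 # the reader catches up
resumed = send_until_blocked(writer)    # the writer can advance again
writer.close()
reader.close()
```

With a blocking socket, as on the JVM side, the writer would simply stall at the point where this sketch raises `BlockingIOError`, and would resume as soon as the reader consumed a batch.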