57 changes: 56 additions & 1 deletion core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -163,8 +163,63 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

/**
* A helper function to create a local RDD iterator and serve it via socket. Partitions are
* collected as separate jobs, in order of index. Partition data is first requested by a
* non-zero integer to start a collection job. The response is prefaced by an integer with 1
* meaning partition data will be served, 0 meaning the local iterator has been consumed,
* and -1 meaning an error occurred during collection. This function is used by
* pyspark.rdd._local_iterator_from_socket().
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from these jobs, and the secret for authentication.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
Member:

This change could also be very beneficial: if the iterator is not fully consumed, it avoids triggering unneeded jobs, whereas the previous behavior eagerly queued jobs for all partitions. In this sense, the change here more closely follows the Scala behavior.

Once the local iterator is out of scope on the Python side, will the remaining jobs still be triggered once the Scala side can no longer write into the closed connection?

Member Author:

No, the remaining jobs are not triggered. The Python iterator finishes consuming the data from the current job, then sends a command for the Scala iterator to stop.

Member:

What about the previous behavior? Would it trigger them? It looks like toLocalIterator won't trigger a job if we don't iterate over a partition's data.

Member Author (@BryanCutler, Apr 5, 2019):

The previous behavior was that the Scala local iterator would advance as long as the write calls to the socket were not blocked. So when Python reads a batch (auto-batched elements) from the current partition, it unblocks the Scala call to write and could start a job to collect the next partition.

Member:

Once the local iterator on the Python side is out of scope and thus not fully consumed, will it block the write call on the Scala side? It seems to me that it will, so we shouldn't see unneeded jobs triggered after that, should we?

Member Author:

The previous behavior is that when the iterator goes out of scope, the socket is eventually closed. This causes the error on the Scala side and the writing thread is terminated, so no more jobs are triggered, but the user sees this error.
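To make the request/response protocol described in the new doc comment concrete, here is a minimal client-side sketch in Python. It is illustrative only: iterate_partitions is a hypothetical helper, and it assumes sockfile is the authenticated file-like socket and serializer the batch serializer. The real client implementation is PyLocalIterable in python/pyspark/rdd.py further down.

from pyspark.serializers import UTF8Deserializer, read_int, write_int

def iterate_partitions(sockfile, serializer):
    """Sketch of the client side of the toLocalIterator socket protocol."""
    while True:
        write_int(1, sockfile)       # non-zero request: collect and stream the next partition
        sockfile.flush()
        status = read_int(sockfile)  # 1 = data follows, 0 = exhausted, -1 = collection error
        if status == 1:
            # load_stream stops at END_OF_DATA_SECTION, i.e. at the end of this partition
            for item in serializer.load_stream(sockfile):
                yield item
        elif status == 0:
            return
        else:
            raise RuntimeError(UTF8Deserializer().loads(sockfile))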

val (port, secret) = SocketAuthServer.setupOneConnectionServer(
authHelper, "serve toLocalIterator") { s =>
val out = new DataOutputStream(s.getOutputStream)
val in = new DataInputStream(s.getInputStream)
Utils.tryWithSafeFinally {

// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
Member Author:

Note: this is the same function as collectPartition(p: Int) in Scala, except here we do not want to flatten the collected arrays

Contributor:

For performance, as mentioned in my questions, would it make sense to use something like an iterator with a look-ahead of, say, 1 partition (or X% of partitions) so we decrease the blocking time?

Member Author:

Yeah, that would have better performance, but it does say in the doc that max memory usage will be the largest partition. Going over that might cause problems for some people, no?

Contributor:

That's a good point @BryanCutler, we could implement the lookahead as a separate PR/JIRA and allow it to be turned on/off. I'd suggest this PR is more about fixing the behaviour of toLocalIterator memory-wise than the out-of-scope issue (although the out-of-scope issue is maybe more visible in the logs).
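For illustration, a minimal sketch of the one-partition look-ahead idea, done on the driver side with a background thread. The helper name and the collect_partition_fns argument (an iterable of zero-argument callables, each collecting one partition) are assumptions for this sketch, not part of the PR; as noted above, it raises peak memory to roughly two partitions.

from concurrent.futures import ThreadPoolExecutor

def iterate_with_lookahead(collect_partition_fns):
    """Yield partition items while the next partition is collected in the background."""
    fns = iter(collect_partition_fns)
    with ThreadPoolExecutor(max_workers=1) as pool:
        try:
            pending = pool.submit(next(fns))   # start collecting the first partition
        except StopIteration:
            return
        for fn in fns:
            current = pending.result()         # wait for the partition already in flight
            pending = pool.submit(fn)          # prefetch the next one while we yield
            for item in current:
                yield item
        for item in pending.result():          # drain the last partition
            yield item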

}

// Read request for data and send next partition if nonzero
var complete = false
while (!complete && in.readInt() != 0) {
if (collectPartitionIter.hasNext) {
try {
// Attempt to collect the next partition
val partitionArray = collectPartitionIter.next()

// Send response there is a partition to read
out.writeInt(1)

// Write the next object and signal end of data for this iteration
writeIteratorToStream(partitionArray.toIterator, out)
out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
out.flush()
} catch {
case e: SparkException =>
Member Author:

We want to catch any errors during the collection job, so I believe the SparkException should be all that is needed here?

Contributor:

I think that's reasonable +1

// Send response that an error occurred followed by error message
out.writeInt(-1)
writeUTF(e.getMessage, out)
complete = true
}
} else {
// Send response there are no more partitions to read and close
out.writeInt(0)
complete = true
}
}
} {
out.close()
in.close()
}
}
Array(port, secret)
}

def readRDDFromFile(
66 changes: 60 additions & 6 deletions python/pyspark/rdd.py
@@ -39,9 +39,9 @@
from itertools import imap as map, ifilter as filter

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \
CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \
UTF8Deserializer, pack_long, read_int, write_int
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -138,15 +138,69 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


def _load_from_socket(sock_info, serializer):
def _create_local_socket(sock_info):
(sockfile, sock) = local_connect_and_auth(*sock_info)
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
# The RDD materialization time is unpredictable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
return sockfile


def _load_from_socket(sock_info, serializer):
sockfile = _create_local_socket(sock_info)
# The socket will be automatically closed when garbage-collected.
Member:

I guess this comment should be moved up into _create_local_socket, too.

return serializer.load_stream(sockfile)


def _local_iterator_from_socket(sock_info, serializer):

class PyLocalIterable(object):
""" Create a synchronous local iterable over a socket """

def __init__(self, _sock_info, _serializer):
self._sockfile = _create_local_socket(_sock_info)
self._serializer = _serializer
self._read_iter = iter([]) # Initialize as empty iterator
self._read_status = 1

def __iter__(self):
while self._read_status == 1:
# Request next partition data from Java
write_int(1, self._sockfile)
self._sockfile.flush()

# If response is 1 then there is a partition to read, if 0 then fully consumed
self._read_status = read_int(self._sockfile)
if self._read_status == 1:

# Load the partition data as a stream and read each item
self._read_iter = self._serializer.load_stream(self._sockfile)
for item in self._read_iter:
yield item

# An error occurred, read error message and raise it
elif self._read_status == -1:
error_msg = UTF8Deserializer().loads(self._sockfile)
raise RuntimeError("An error occurred while reading the next element from "
"toLocalIterator: {}".format(error_msg))

def __del__(self):
# If local iterator is not fully consumed,
if self._read_status == 1:
try:
# Finish consuming partition data stream
for _ in self._read_iter:
pass
# Tell Java to stop sending data and close connection
write_int(0, self._sockfile)
self._sockfile.flush()
except Exception:
# Ignore any errors, socket is automatically closed when garbage-collected
pass

return iter(PyLocalIterable(sock_info, serializer))


def ignore_unicode_prefix(f):
"""
Ignore the 'u' prefix of string in doc tests, to make it works
@@ -2386,7 +2440,7 @@ def toLocalIterator(self):
"""
with SCCallSiteSync(self.context) as css:
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
return _load_from_socket(sock_info, self._jrdd_deserializer)
return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)

def barrier(self):
"""
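As a usage illustration of the new behavior (a hypothetical example, not part of the diff, assuming an existing SparkContext sc): when the local iterator is only partially consumed and then dropped, PyLocalIterable.__del__ drains the current partition's stream and sends 0, so the JVM stops scheduling jobs for the remaining partitions instead of erroring on a closed socket.

rdd = sc.range(0, 1000, numSlices=10)

it = rdd.toLocalIterator()
first_five = [next(it) for _ in range(5)]  # only the first partition's collection job runs

# Dropping the iterator triggers PyLocalIterable.__del__, which finishes reading the current
# partition and tells the JVM to stop; no jobs are launched for the untouched partitions.
del it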
4 changes: 2 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -28,7 +28,7 @@
import warnings

from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, ignore_unicode_prefix
from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
@@ -528,7 +528,7 @@ def toLocalIterator(self):
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.toPythonIterator()
return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))

@ignore_unicode_prefix
@since(1.3)
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -677,6 +677,34 @@ def test_repr_behaviors(self):
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())

def test_to_local_iterator(self):
df = self.spark.range(8, numPartitions=4)
expected = df.collect()
it = df.toLocalIterator()
self.assertEqual(expected, list(it))

# Test DataFrame with empty partition
df = self.spark.range(3, numPartitions=4)
it = df.toLocalIterator()
expected = df.collect()
self.assertEqual(expected, list(it))

def test_to_local_iterator_not_fully_consumed(self):
# SPARK-23961: toLocalIterator throws exception when not fully consumed
# Create a DataFrame large enough so that write to socket will eventually block
df = self.spark.range(1 << 20, numPartitions=2)
it = df.toLocalIterator()
self.assertEqual(df.take(1)[0], next(it))
with QuietTest(self.sc):
it = None # remove iterator from scope, socket is closed when cleaned up
# Make sure normal df operations still work
result = []
for i, row in enumerate(df.toLocalIterator()):
result.append(row)
if i == 7:
break
self.assertEqual(df.take(8), result)
Member Author:

Note: this would not have crashed before, it only generated the error, and I don't think it's possible to check whether this error happened.

Contributor:

Is the error a JVM error? If so, we could grab the stderr/stdout and look for the error message there.

Member Author:

Yeah, that might work. I can give it a try.

Contributor:

Did anything come of this? It's optional so don't block on it.

Member Author:

Oh right, I forgot about this. Let me give it a shot now.

Member Author:

It ends up being a little complicated, maybe better to try as a followup



class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
# These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
31 changes: 28 additions & 3 deletions python/pyspark/tests/test_rdd.py
@@ -60,15 +60,12 @@ def test_sum(self):
self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())

def test_to_localiterator(self):
from time import sleep
rdd = self.sc.parallelize([1, 2, 3])
it = rdd.toLocalIterator()
sleep(5)
Member Author:

This sleep is unnecessary. rdd.toLocalIterator makes the socket connection and iterating starts reading from the pyspark serializer.

Contributor:

I like less sleeps in the code <3

self.assertEqual([1, 2, 3], sorted(it))

rdd2 = rdd.repartition(1000)
it2 = rdd2.toLocalIterator()
sleep(5)
Member Author:

ditto

self.assertEqual([1, 2, 3], sorted(it2))

def test_save_as_textfile_with_unicode(self):
@@ -736,6 +733,34 @@ def test_overwritten_global_func(self):
global_func = lambda: "Yeah"
self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")

def test_to_local_iterator_failure(self):
# SPARK-27548 toLocalIterator task failure not propagated to Python driver

def fail(_):
raise RuntimeError("local iterator error")

rdd = self.sc.range(10).map(fail)

with self.assertRaisesRegexp(Exception, "local iterator error"):
for _ in rdd.toLocalIterator():
pass

def test_to_local_iterator_collects_single_partition(self):
# Test that partitions are not computed until requested by iteration

def fail_last(x):
if x == 9:
raise RuntimeError("This should not be hit")
return x

rdd = self.sc.range(12, numSlices=4).map(fail_last)
it = rdd.toLocalIterator()

# Only consume first 4 elements from partitions 1 and 2, this should not collect the last
# partition which would trigger the error
for i in range(4):
self.assertEqual(i, next(it))


if __name__ == "__main__":
import unittest