
Commit 82c18c2

Bryan Cutler authored and hyukjinkwon committed
[SPARK-23030][SQL][PYTHON] Use Arrow stream format for creating from and collecting Pandas DataFrames
## What changes were proposed in this pull request?

This changes the calls of `toPandas()` and `createDataFrame()` to use the Arrow stream format when Arrow is enabled. Previously, Arrow data was written to byte arrays where each chunk is an output of the Arrow file format. This was mainly due to constraints at the time, and caused some overhead by writing the schema/footer on each chunk of data and then having to read multiple Arrow file inputs and concat them together.

Using the Arrow stream format improves on this by increasing performance, lowering memory overhead for the average case, and simplifying the code. Here are the details of this change:

**toPandas()**

_Before:_ Spark internal rows are converted to the Arrow file format; each group of records is a complete Arrow file which contains the schema and other metadata. Next a collect is done and an Array of Arrow files is the result. After that, each Arrow file is sent to the Python driver, which then loads each file and concats them into a single Arrow DataFrame.

_After:_ Spark internal rows are converted to ArrowRecordBatches directly, which is the simplest Arrow component for IPC data transfers. The driver JVM then immediately starts serving data to Python as an Arrow stream, sending the schema first. It then starts a Spark job with a custom handler that sends Arrow RecordBatches to Python. Partitions arriving in order are sent immediately, and out-of-order partitions are buffered until the ones that precede them come in. This improves performance, simplifies memory usage on executors, and improves the average memory usage on the JVM driver. Since the order of partitions must be preserved, the worst case is that the first partition is the last to arrive, so all data must be buffered in memory until then. This case is no worse than before, when doing a full collect.

**createDataFrame()**

_Before:_ A Pandas DataFrame is split into parts and each part is made into an Arrow file. Then each file is prefixed by the buffer size and written to a temp file. The temp file is read and each Arrow file is parallelized as a byte array.

_After:_ A Pandas DataFrame is split into parts, then an Arrow stream is written to a temp file where each part is an ArrowRecordBatch. The temp file is read as a stream and the Arrow messages are examined. If a message is an ArrowRecordBatch, the data is saved as a byte array. After reading the file, each ArrowRecordBatch is parallelized as a byte array. This has slightly more processing than before because we must look at each Arrow message to extract the record batches, but performance ends up a little better. It is cleaner in the sense that IPC from Python to the JVM is done over a single Arrow stream.

## How was this patch tested?

Added new unit tests for the additions to ArrowConverters in Scala; existing tests cover Python.

## Performance Tests - toPandas

Tests were run on a 4 node standalone cluster with 32 cores total (Ubuntu 14.04.1, OpenJDK 8), measuring wall clock time to execute `toPandas()` and taking the average best time of 5 runs/5 loops each.
Test code
```python
df = spark.range(1 << 25, numPartitions=32).toDF("id") \
    .withColumn("x1", rand()).withColumn("x2", rand()) \
    .withColumn("x3", rand()).withColumn("x4", rand())
for i in range(5):
    start = time.time()
    _ = df.toPandas()
    elapsed = time.time() - start
```

Current Master | This PR
---------------|--------
5.803557 | 5.16207
5.409119 | 5.133671
5.493509 | 5.147513
5.433107 | 5.105243
5.488757 | 5.018685

Avg Master | Avg This PR
-----------|------------
5.5256098 | 5.1134364

Speedup of **1.08060595**

## Performance Tests - createDataFrame

Tests were run on a 4 node standalone cluster with 32 cores total (Ubuntu 14.04.1, OpenJDK 8), measuring wall clock time to execute `createDataFrame()` and get the first record, taking the average best time of 5 runs/5 loops each.

Test code
```python
def run():
    pdf = pd.DataFrame(np.random.rand(10000000, 10))
    spark.createDataFrame(pdf).first()

for i in range(6):
    start = time.time()
    run()
    elapsed = time.time() - start
    gc.collect()
    print("Run %d: %f" % (i, elapsed))
```

Current Master | This PR
---------------|--------
6.234608 | 5.665641
6.32144 | 5.3475
6.527859 | 5.370803
6.95089 | 5.479151
6.235046 | 5.529167

Avg Master | Avg This PR
-----------|------------
6.4539686 | 5.4784524

Speedup of **1.178064192**

## Memory Improvements

**toPandas()**

The most significant improvement is the reduction of the upper-bound space complexity in the JVM driver. Before, the entire dataset was collected in the JVM first before sending it to Python. With this change, as soon as a partition is collected, the result handler immediately sends it to Python, so the upper bound is the size of the largest partition. Also, using the Arrow stream format is more efficient because the schema is written once per stream, followed by record batches. The schema is now only sent once from the driver JVM to Python. Before, multiple Arrow files were used that each contained the schema. This duplicated schema was created in the executors, sent to the driver JVM, and then to Python, where all but the first one received are discarded.

I verified the upper-bound limit by running a test that collects more data than the driver JVM memory available, using these settings on a standalone cluster:

```
spark.driver.memory 1g
spark.executor.memory 5g
spark.sql.execution.arrow.enabled true
spark.sql.execution.arrow.fallback.enabled false
spark.sql.execution.arrow.maxRecordsPerBatch 0
spark.driver.maxResultSize 2g
```

Test code:
```python
from pyspark.sql.functions import rand
df = spark.range(1 << 25, numPartitions=32).toDF("id") \
    .withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand())
df.toPandas()
```

This makes a total data size of 33554432 × 8 × 4 = 1073741824 bytes. With the current master, it fails with OOM, but it passes using this PR.

**createDataFrame()**

No significant change in memory, except that using the stream format instead of separate files avoids duplicating the schema, similar to toPandas above. The process of reading the stream and parallelizing the batches does cause the record batch message metadata to be copied, but its size is insignificant.

Closes #21546 from BryanCutler/arrow-toPandas-stream-SPARK-23030.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
1 parent ff8dcc1 commit 82c18c2

File tree

9 files changed: +326 −163 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 22 additions & 2 deletions
```diff
@@ -399,6 +399,26 @@ private[spark] object PythonRDD extends Logging {
    * data collected from this job, and the secret for authentication.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
+    serveToStream(threadName) { out =>
+      writeIteratorToStream(items, new DataOutputStream(out))
+    }
+  }
+
+  /**
+   * Create a socket server and background thread to execute the writeFunc
+   * with the given OutputStream.
+   *
+   * The socket server can only accept one connection, or close if no connection
+   * in 15 seconds.
+   *
+   * Once a connection comes in, it will execute the block of code and pass in
+   * the socket output stream.
+   *
+   * The thread will terminate after the block of code is executed or any
+   * exceptions happen.
+   */
+  private[spark] def serveToStream(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 15 seconds
     serverSocket.setSoTimeout(15000)
@@ -410,9 +430,9 @@ private[spark] object PythonRDD extends Logging {
         val sock = serverSocket.accept()
         authHelper.authClient(sock)
 
-        val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+        val out = new BufferedOutputStream(sock.getOutputStream)
         Utils.tryWithSafeFinally {
-          writeIteratorToStream(items, out)
+          writeFunc(out)
         } {
           out.close()
           sock.close()
```
python/pyspark/context.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -494,10 +494,14 @@ def f(split, iterator):
         c = list(c)    # Make it a list so we can compute its length
         batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
         serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
+
+        def reader_func(temp_filename):
+            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
+
+        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
         return RDD(jrdd, self, serializer)
 
-    def _serialize_to_jvm(self, data, parallelism, serializer):
+    def _serialize_to_jvm(self, data, serializer, reader_func):
         """
         Calling the Java parallelize() method with an ArrayList is too slow,
         because it sends O(n) Py4J commands.  As an alternative, serialized
@@ -507,8 +511,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer):
         try:
             serializer.dump_stream(data, tempFile)
             tempFile.close()
-            readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
+            return reader_func(tempFile.name)
         finally:
             # readRDDFromFile eagerily reads the file so we can delete right after.
             os.unlink(tempFile.name)
```
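The effect of this change is that `_serialize_to_jvm` no longer knows how the JVM will read the temp file; the caller passes a `reader_func` instead, which lets the same helper serve both the pickle-based `parallelize()` path here and the new Arrow stream path in `session.py`. A condensed sketch of the contract (simplified from the patch, not a drop-in copy):

```python
import os
import tempfile

def serialize_to_jvm(data, serializer, reader_func):
    # Dump `data` to a temp file using `serializer`, then hand the file name to
    # `reader_func`, which is expected to read it eagerly on the JVM side
    # (e.g. PythonRDD.readRDDFromFile or PythonSQLUtils.arrowReadStreamFromFile).
    fd, path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "wb") as tmp:
            serializer.dump_stream(data, tmp)
        return reader_func(path)
    finally:
        # Safe to delete immediately because reader_func has already read the file.
        os.unlink(path)
```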

python/pyspark/serializers.py

Lines changed: 17 additions & 13 deletions
```diff
@@ -185,27 +185,31 @@ def loads(self, obj):
         raise NotImplementedError
 
 
-class ArrowSerializer(FramedSerializer):
+class ArrowStreamSerializer(Serializer):
     """
-    Serializes bytes as Arrow data with the Arrow file format.
+    Serializes Arrow record batches as a stream.
     """
 
-    def dumps(self, batch):
+    def dump_stream(self, iterator, stream):
         import pyarrow as pa
-        import io
-        sink = io.BytesIO()
-        writer = pa.RecordBatchFileWriter(sink, batch.schema)
-        writer.write_batch(batch)
-        writer.close()
-        return sink.getvalue()
+        writer = None
+        try:
+            for batch in iterator:
+                if writer is None:
+                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+                writer.write_batch(batch)
+        finally:
+            if writer is not None:
+                writer.close()
 
-    def loads(self, obj):
+    def load_stream(self, stream):
         import pyarrow as pa
-        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
-        return reader.read_all()
+        reader = pa.open_stream(stream)
+        for batch in reader:
+            yield batch
 
     def __repr__(self):
-        return "ArrowSerializer"
+        return "ArrowStreamSerializer"
 
 
 def _create_batch(series, timezone):
```
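The stream format writes the schema once, followed by any number of record batches, which is why `dump_stream` can lazily create the writer on the first batch. A small self-contained round trip using the same pyarrow calls (spelled `pa.ipc.open_stream` here; the patch uses the older `pa.open_stream` alias, so adjust for your pyarrow version):

```python
import io
import pyarrow as pa

# Write a small record batch twice to an in-memory Arrow stream, mirroring
# ArrowStreamSerializer.dump_stream (schema written once up front).
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["id"])
sink = io.BytesIO()
writer = pa.RecordBatchStreamWriter(sink, batch.schema)
writer.write_batch(batch)
writer.write_batch(batch)
writer.close()

# Read the batches back as a stream, mirroring load_stream.
reader = pa.ipc.open_stream(sink.getvalue())
batches = list(reader)
assert pa.Table.from_batches(batches).num_rows == 6
```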

python/pyspark/sql/dataframe.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -29,7 +29,7 @@
 
 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2118,10 +2118,9 @@ def toPandas(self):
                 from pyspark.sql.types import _check_dataframe_convert_date, \
                     _check_dataframe_localize_timestamps
                 import pyarrow
-
-                tables = self._collectAsArrow()
-                if tables:
-                    table = pyarrow.concat_tables(tables)
+                batches = self._collectAsArrow()
+                if len(batches) > 0:
+                    table = pyarrow.Table.from_batches(batches)
                     pdf = table.to_pandas()
                     pdf = _check_dataframe_convert_date(pdf, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
@@ -2170,14 +2169,14 @@ def toPandas(self):
 
     def _collectAsArrow(self):
         """
-        Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
-        and available.
+        Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
+        and available on driver and worker Python environments.
 
         .. note:: Experimental.
         """
         with SCCallSiteSync(self._sc) as css:
             sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowSerializer()))
+        return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
 
     ##########################################################################################
     # Pandas compatibility
```
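A minimal usage sketch for the collect path, assuming a running `SparkSession` named `spark` and pyarrow installed on the driver (the config key is the one in use at the time of this patch):

```python
from pyspark.sql.functions import rand

# With Arrow enabled, toPandas() collects the partitions as Arrow record
# batches over a single stream and builds one pyarrow.Table from them.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
df = spark.range(1 << 20).toDF("id").withColumn("x", rand())
pdf = df.toPandas()
print(pdf.shape)
```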

python/pyspark/sql/session.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
         to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
-        from pyspark.serializers import ArrowSerializer, _create_batch
+        from pyspark.serializers import ArrowStreamSerializer, _create_batch
         from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
         from pyspark.sql.utils import require_minimum_pandas_version, \
             require_minimum_pyarrow_version
@@ -539,10 +539,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
                 struct.names[i] = name
             schema = struct
 
-        # Create the Spark DataFrame directly from the Arrow data and schema
-        jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
-        jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
-            jrdd, schema.json(), self._wrapped._jsqlContext)
+        def reader_func(temp_filename):
+            return self._jvm.PythonSQLUtils.arrowReadStreamFromFile(
+                self._wrapped._jsqlContext, temp_filename, schema.json())
+
+        # Create Spark DataFrame from Arrow stream file, using one batch per partition
+        jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func)
         df = DataFrame(jdf, self._wrapped)
         df._schema = schema
         return df
```
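And the corresponding sketch for the create path, under the same assumptions (a running `SparkSession` named `spark`, pandas and pyarrow installed):

```python
import numpy as np
import pandas as pd

# With Arrow enabled, createDataFrame(pdf) writes the pandas data to a temp
# file as one Arrow stream and parallelizes one record batch per partition.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf = pd.DataFrame(np.random.rand(100000, 4), columns=["x1", "x2", "x3", "x4"])
df = spark.createDataFrame(pdf)
print(df.count(), len(df.columns))
```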

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 46 additions & 10 deletions
```diff
@@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
+import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
@@ -3273,13 +3273,49 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Array[Any] = {
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      val iter: Iterator[Array[Byte]] =
-        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
-      PythonRDD.serveIterator(iter, "serve-Arrow")
+      PythonRDD.serveToStream("serve-Arrow") { out =>
+        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
+        val arrowBatchRdd = toArrowBatchRdd(plan)
+        val numPartitions = arrowBatchRdd.partitions.length
+
+        // Store collection results for worst case of 1 to N-1 partitions
+        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
+        var lastIndex = -1  // index of last partition written
+
+        // Handler to eagerly write partitions to Python in order
+        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+          // If result is from next partition in order
+          if (index - 1 == lastIndex) {
+            batchWriter.writeBatches(arrowBatches.iterator)
+            lastIndex += 1
+            // Write stored partitions that come next in order
+            while (lastIndex < results.length && results(lastIndex) != null) {
+              batchWriter.writeBatches(results(lastIndex).iterator)
+              results(lastIndex) = null
+              lastIndex += 1
+            }
+            // After last batch, end the stream
+            if (lastIndex == results.length) {
+              batchWriter.end()
+            }
+          } else {
+            // Store partitions received out of order
+            results(index - 1) = arrowBatches
+          }
+        }
+
+        sparkSession.sparkContext.runJob(
+          arrowBatchRdd,
+          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
+          0 until numPartitions,
+          handlePartitionBatches)
+      }
     }
   }
 
@@ -3386,20 +3422,20 @@ class Dataset[T] private[sql](
     }
   }
 
-  /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
+  /** Convert to an RDD of serialized ArrowRecordBatches. */
+  private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
     plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
-      ArrowConverters.toPayloadIterator(
+      ArrowConverters.toBatchIterator(
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }
 
   // This is only used in tests, for now.
-  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
-    toArrowPayload(queryExecution.executedPlan)
+  private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
+    toArrowBatchRdd(queryExecution.executedPlan)
   }
 }
```
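The ordering logic in `handlePartitionBatches` is the heart of the memory improvement: only partitions that arrive ahead of their turn are buffered, and everything else is written to the stream as soon as it lands on the driver. A Python paraphrase of the same algorithm (purely illustrative; `batch_writer` stands in for the Scala `ArrowBatchStreamWriter`, with hypothetical `write_batches`/`end` methods):

```python
def make_partition_handler(batch_writer, num_partitions):
    # Buffer for partitions 1..N-1 that arrive before their predecessors.
    results = [None] * (num_partitions - 1)
    state = {"last_index": -1}  # index of the last partition written

    def handle(index, arrow_batches):
        if index - 1 == state["last_index"]:
            # Next partition in order: write it straight to the stream.
            batch_writer.write_batches(arrow_batches)
            state["last_index"] += 1
            # Flush any buffered partitions that are now next in order.
            while (state["last_index"] < len(results)
                   and results[state["last_index"]] is not None):
                batch_writer.write_batches(results[state["last_index"]])
                results[state["last_index"]] = None
                state["last_index"] += 1
            # After the last partition, end the Arrow stream.
            if state["last_index"] == len(results):
                batch_writer.end()
        else:
            # Arrived early: store until its predecessors have been written.
            results[index - 1] = arrow_batches

    return handle
```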

sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala

Lines changed: 11 additions & 10 deletions
```diff
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.api.python
 
-import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
@@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils {
   }
 
   /**
-   * Python Callable function to convert ArrowPayloads into a [[DataFrame]].
+   * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
+   * using each serialized ArrowRecordBatch as a partition.
    *
-   * @param payloadRDD A JavaRDD of ArrowPayloads.
-   * @param schemaString JSON Formatted Schema for ArrowPayloads.
    * @param sqlContext The active [[SQLContext]].
-   * @return The converted [[DataFrame]].
+   * @param filename File to read the Arrow stream from.
+   * @param schemaString JSON Formatted Spark schema for Arrow batches.
+   * @return A new [[DataFrame]].
    */
-  def arrowPayloadToDataFrame(
-      payloadRDD: JavaRDD[Array[Byte]],
-      schemaString: String,
-      sqlContext: SQLContext): DataFrame = {
-    ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
+  def arrowReadStreamFromFile(
+      sqlContext: SQLContext,
+      filename: String,
+      schemaString: String): DataFrame = {
+    val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
+    ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
   }
 }
```
