
Commit 034cb13

dvogelbacher authored and HyukjinKwon committed
[SPARK-27778][PYTHON] Fix toPandas conversion of empty DataFrame with Arrow enabled
## What changes were proposed in this pull request?

#22275 introduced a performance improvement where we send partitions out of order to Python and then, as a last step, send the partition order as well. However, if there are no partitions we will never send the partition order, and we will get an "EOFError" on the Python side. This PR fixes this by also sending the partition order if there are no partitions present.

## How was this patch tested?

New unit test added.

Closes #24650 from dvogelbacher/dv/fixNoPartitionArrowConversion.

Authored-by: David Vogelbacher <dvogelbacher@palantir.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
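
For context, a minimal sketch of the failing scenario this commit addresses, mirroring the new unit test below; it assumes a local SparkSession and the spark.sql.execution.arrow.enabled flag, and the variable names are illustrative, not part of the patch:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

# A DataFrame built from an empty RDD has no partitions, so the old code never
# invoked the per-partition handler, never wrote the batch-order indices, and
# left the Python reader to fail with an EOFError.
schema = StructType([StructField("field1", StringType(), True)])
df = spark.createDataFrame(spark.sparkContext.emptyRDD(), schema)

pdf = df.toPandas()   # raised EOFError before this fix
print(pdf.empty)      # True: an empty pandas DataFrame with column "field1"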
1 parent 03c9e8a commit 034cb13

File tree: 2 files changed (+21 lines, -21 lines)

python/pyspark/sql/tests/test_arrow.py
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala


python/pyspark/sql/tests/test_arrow.py

Lines changed: 8 additions & 0 deletions
@@ -183,6 +183,14 @@ def test_filtered_frame(self):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_no_partition_frame(self):
+        schema = StructType([StructField("field1", StringType(), True)])
+        df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
+        pdf = df.toPandas()
+        self.assertEqual(len(pdf.columns), 1)
+        self.assertEqual(pdf.columns[0], "field1")
+        self.assertTrue(pdf.empty)
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
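
For contrast, the pre-existing test_filtered_frame covers an empty result that still has partitions (each task simply returns zero Arrow batches), whereas the new test covers a plan with no partitions at all, which is the case the old per-partition handler never saw. A small illustrative snippet of that distinction (ad hoc names, not part of the test suite; the partition counts shown are the expected ones, not asserted by the patch):

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.getOrCreate()

# Empty *result*, but the plan still has partitions: the handler runs for each
# of them, and the stream used to be finalized in the last such call.
filtered = spark.range(3).toDF("i").filter("i < 0")

# Empty *plan*: an emptyRDD-backed DataFrame has zero partitions, so the old
# handler-based finalization never ran and the Python side hit an EOFError.
schema = StructType([StructField("field1", StringType(), True)])
no_partitions = spark.createDataFrame(spark.sparkContext.emptyRDD(), schema)

print(filtered.rdd.getNumPartitions())       # > 0
print(no_partitions.rdd.getNumPartitions())  # expected 0 for an emptyRDD source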

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 13 additions & 21 deletions
@@ -3299,43 +3299,35 @@ class Dataset[T] private[sql](
       PythonRDD.serveToStream("serve-Arrow") { outputStream =>
         val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        val numPartitions = arrowBatchRdd.partitions.length
 
         // Batches ordered by (index of partition, batch index in that partition) tuple
         val batchOrder = ArrayBuffer.empty[(Int, Int)]
-        var partitionCount = 0
 
         // Handler to eagerly write batches to Python as they arrive, un-ordered
-        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+        val handlePartitionBatches = (index: Int, arrowBatches: Array[Array[Byte]]) =>
          if (arrowBatches.nonEmpty) {
            // Write all batches (can be more than 1) in the partition, store the batch order tuple
            batchWriter.writeBatches(arrowBatches.iterator)
            arrowBatches.indices.foreach {
              partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
            }
          }
-          partitionCount += 1
-
-          // After last batch, end the stream and write batch order indices
-          if (partitionCount == numPartitions) {
-            batchWriter.end()
-            out.writeInt(batchOrder.length)
-            // Sort by (index of partition, batch index in that partition) tuple to get the
-            // overall_batch_index from 0 to N-1 batches, which can be used to put the
-            // transferred batches in the correct order
-            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-              out.writeInt(overallBatchIndex)
-            }
-            out.flush()
-          }
-        }
 
+        val arrowBatchRdd = toArrowBatchRdd(plan)
         sparkSession.sparkContext.runJob(
           arrowBatchRdd,
-          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
-          0 until numPartitions,
+          (it: Iterator[Array[Byte]]) => it.toArray,
          handlePartitionBatches)
+
+        // After processing all partitions, end the stream and write batch order indices
+        batchWriter.end()
+        out.writeInt(batchOrder.length)
+        // Sort by (index of partition, batch index in that partition) tuple to get the
+        // overall_batch_index from 0 to N-1 batches, which can be used to put the
+        // transferred batches in the correct order
+        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+          out.writeInt(overallBatchIndex)
+        }
       }
     }
   }
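
With this change the per-partition handler only writes batches and records their order; batchWriter.end(), the index count, and the indices themselves are written unconditionally after runJob returns, so a plan with zero partitions still produces a complete, well-formed stream. A minimal, illustrative Python sketch of how a consumer can apply the transmitted indices (this is not pyspark's actual reader code, and reorder_batches is an ad hoc name):

def reorder_batches(batches_in_arrival_order, indices_in_logical_order):
    # The Scala side sorts its (partition index, batch index in partition) tuples
    # and, for each logical position, writes the arrival position of the batch
    # that belongs there; applying that permutation restores the original order.
    return [batches_in_arrival_order[i] for i in indices_in_logical_order]

# Zero partitions means both the batch list and the index list are empty, so the
# result is simply an empty list -- provided the server actually sends the
# (empty) index list, which is what this commit guarantees.
print(reorder_batches([], []))  # []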
