[ARROW] Arrow serialization should not introduce extra shuffle for outermost limit #4662

Closed
wants to merge 28 commits
Changes from 8 commits
Commits (28)
a584943  arrow take (cfmcgrady, Mar 23, 2023)
8593d85  driver slice last batch (cfmcgrady, Mar 24, 2023)
0088671  refine (cfmcgrady, Mar 29, 2023)
ed8c692  refactor (cfmcgrady, Apr 3, 2023)
4212a89  refactor and add ut (cfmcgrady, Apr 4, 2023)
6c5b1eb  add ut (cfmcgrady, Apr 4, 2023)
ee5a756  revert unnecessarily changes (cfmcgrady, Apr 4, 2023)
4e7ca54  unnecessarily changes (cfmcgrady, Apr 4, 2023)
885cf2c  infer row size by schema.defaultSize (cfmcgrady, Apr 4, 2023)
25e4f05  add docs (cfmcgrady, Apr 4, 2023)
03d0747  address comment (cfmcgrady, Apr 6, 2023)
2286afc  reflective calla AdaptiveSparkPlanExec.finalPhysicalPlan (cfmcgrady, Apr 6, 2023)
81886f0  address comment (cfmcgrady, Apr 6, 2023)
e3bf84c  refactor (cfmcgrady, Apr 6, 2023)
d70aee3  SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x (cfmcgrady, Apr 6, 2023)
4cef204  SparkArrowbasedOperationSuite adapt Spark-3.1.x (cfmcgrady, Apr 6, 2023)
573a262  fix (cfmcgrady, Apr 6, 2023)
c83cf3f  SparkArrowbasedOperationSuite adapt Spark-3.1.x (cfmcgrady, Apr 6, 2023)
9ffb44f  make toBatchIterator private (cfmcgrady, Apr 6, 2023)
b72bc6f  add offset support to adapt Spark-3.4.x (cfmcgrady, Apr 6, 2023)
22cc70f  add ut (cfmcgrady, Apr 6, 2023)
8280783  add `isStaticConfigKey` to adapt Spark-3.1.x (cfmcgrady, Apr 7, 2023)
6d596fc  address comment (cfmcgrady, Apr 7, 2023)
6064ab9  limit = 0 test case (cfmcgrady, Apr 7, 2023)
3700839  SparkArrowbasedOperationSuite adapt Spark-3.1.x (cfmcgrady, Apr 7, 2023)
facc13f  exclude rule OptimizeLimitZero (cfmcgrady, Apr 7, 2023)
130bcb1  finally close (cfmcgrady, Apr 7, 2023)
82c912e  close vector (cfmcgrady, Apr 8, 2023)
@@ -21,10 +21,8 @@ import java.util.concurrent.RejectedExecutionException
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
@@ -187,42 +185,22 @@ class ArrowBasedExecuteStatement(
     handle) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.toLocalIterator
-    }
+    toArrowBatchLocalIterator(convertComplexType(resultDF))
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.collect()
-    }
+    executeCollect(convertComplexType(resultDF))
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-    // this will introduce shuffle and hurt performance
-    val limitedResult = resultDF.limit(maxRows)
-    collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-      rdd.collect()
-    }
-  }
-
-  /**
-   * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
-   * operation, so that we can track the arrow-based queries on the UI tab.
-   */
-  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
-    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-      df.queryExecution.executedPlan.resetMetrics()
-      action(SparkDatasetHelper.toArrowBatchRdd(df))
-    }
+    executeCollect(convertComplexType(resultDF.limit(maxRows)))
   }
 
   override protected def isArrowBasedOperation: Boolean = true
 
   override val resultFormat = "arrow"
 
   private def convertComplexType(df: DataFrame): DataFrame = {
-    SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+    convertTopLevelComplexTypeToHiveString(df, timestampAsString)
   }
 
 }
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.execution.CollectLimitExec

object ArrowCollectUtils extends SQLConfHelper {

type Batch = (Array[Byte], Long)

/**
* Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be
* summarized in the following steps:
* 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty
* array of batches.
* 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of
* data to collect.
* 3. Use an iterative approach to collect data in batches until the specified limit is reached.
* In each iteration, it selects a subset of the partitions of the RDD to scan and tries to
* collect data from them.
* 4. For each partition subset, we use the runJob method of the Spark context to execute a
* closure that scans the partition data and converts it to Arrow batches.
* 5. Check if the collected data reaches the specified limit. If not, it selects another subset
* of partitions to scan and repeats the process until the limit is reached or all partitions
* have been scanned.
* 6. Return an array of all the collected Arrow batches.
*
* Note that:
* 1. The total row count of the returned Arrow batches is >= `limit` when the input
*    DataFrame contains more than `limit` rows.
* 2. The `takeFromEnd` logic is not implemented.
*
* @return an array of (Arrow batch bytes, batch row count) pairs
*/
def takeAsArrowBatches(
collectLimitExec: CollectLimitExec,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String): Array[Batch] = {
val n = collectLimitExec.limit
val schema = collectLimitExec.schema
if (n == 0) {
return new Array[Batch](0)
} else {
val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// TODO: refactor and reuse the code from RDD's take()
val childRDD = collectLimitExec.child.execute()
val buf = new ArrayBuffer[Batch]
var bufferedRowSize = 0L
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (bufferedRowSize < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = limitInitialNumPartitions
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
val left = n - bufferedRowSize
// As left > 0, numPartsToTry is always >= 1
numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
}
}

val partsToScan =
partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

val sc = collectLimitExec.session.sparkContext
val res = sc.runJob(
childRDD,
(it: Iterator[InternalRow]) => {
val batches = ArrowConvertersHelper.toBatchIterator(
it,
schema,
maxRecordsPerBatch,
maxEstimatedBatchSize,
n,
timeZoneId)
batches.map(b => b -> batches.rowCountInLastBatch).toArray
},
partsToScan)

var i = 0
while (bufferedRowSize < n && i < res.length) {
var j = 0
val batches = res(i)
while (j < batches.length && n > bufferedRowSize) {
val batch = batches(j)
val (_, batchSize) = batch
buf += batch
bufferedRowSize += batchSize
j += 1
}
i += 1
}
partsScanned += partsToScan.size
}

buf.toArray
}
}

/**
* Spark introduced the config `spark.sql.limit.initialNumPartitions` in 3.4.0, see SPARK-40211.
*/
def limitInitialNumPartitions: Int = {
conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
.toInt
}
}
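
For context, a minimal spark-shell style sketch of how takeAsArrowBatches above is meant to be driven. It is not part of this PR: the SparkSession setup, the sample query, and the 4 MiB batch-size cap are illustrative assumptions, and with AQE enabled the limit node may be wrapped in AdaptiveSparkPlanExec, which later commits in this PR resolve via finalPhysicalPlan.

// Hypothetical usage sketch; the SparkSession setup, query, and 4 MiB cap are
// illustrative and not part of the PR.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.execution.arrow.ArrowCollectUtils

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(0, 100000).toDF("id").limit(10)

df.queryExecution.executedPlan match {
  // With AQE enabled the limit node may be wrapped in AdaptiveSparkPlanExec;
  // later commits in this PR resolve that case via finalPhysicalPlan.
  case limitExec: CollectLimitExec =>
    val batches: Array[ArrowCollectUtils.Batch] = ArrowCollectUtils.takeAsArrowBatches(
      limitExec,
      maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch,
      maxEstimatedBatchSize = 4L * 1024 * 1024,
      timeZoneId = spark.sessionState.conf.sessionLocalTimeZone)
    // Each Batch pairs the serialized Arrow payload with its row count.
    println(s"collected ${batches.map(_._2).sum} rows in ${batches.length} batches")
  case other =>
    println(s"outermost node is ${other.nodeName}; fall back to a regular collect")
}

Because the limit is satisfied by scanning partitions incrementally from the driver, no extra single-partition shuffle is introduced, which is the point of this change.
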
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

object ArrowConvertersHelper extends Logging {

/**
* Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters#toBatchIterator]], the
* returned iterator exposes the row count of each produced batch via `rowCountInLastBatch` and
* stops emitting new batches once `limit` rows have been written.
*/
def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
limit: Long,
timeZoneId: String): ArrowBatchIterator = {
new ArrowBatchIterator(
rowIter,
schema,
maxRecordsPerBatch,
maxEstimatedBatchSize,
limit,
timeZoneId,
TaskContext.get)
}

private[sql] class ArrowBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
limit: Long,
timeZoneId: String,
context: TaskContext)
extends Iterator[Array[Byte]] {

protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}",
0,
Long.MaxValue)

private val root = VectorSchemaRoot.create(arrowSchema, allocator)
protected val unloader = new VectorUnloader(root)
protected val arrowWriter = ArrowWriter.create(root)

Option(context).foreach {
_.addTaskCompletionListener[Unit] { _ =>
root.close()
allocator.close()
}
}

override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
root.close()
allocator.close()
false
}

var rowCountInLastBatch: Long = 0
var rowCount: Long = 0

override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

rowCountInLastBatch = 0
var estimatedBatchSize = 0
Utils.tryWithSafeFinally {

// Always write the first row.
while (rowIter.hasNext && (
// For maxBatchSize and maxRecordsPerBatch, respect whatever smaller.
// If the size in bytes is positive (set properly), always write the first row.
rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
// If the size in bytes of rows are 0 or negative, unlimit it.
estimatedBatchSize <= 0 ||
estimatedBatchSize < maxEstimatedBatchSize ||
// If the size of rows are 0 or negative, unlimit it.
maxRecordsPerBatch <= 0 ||
rowCountInLastBatch < maxRecordsPerBatch ||
rowCount < limit)) {
val row = rowIter.next()
arrowWriter.write(row)
estimatedBatchSize += (row match {
case ur: UnsafeRow => ur.getSizeInBytes
// Trying to estimate the size of the current row, assuming 16 bytes per value.
case ir: InternalRow => ir.numFields * 16
Contributor:

in general, we can infer row size by schema.defaultSize

Contributor Author:

sorry for the lack of documentation.
This class ArrowBatchIterator is derived from org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator, with two key differences:

  1. there is no requirement to write the schema at the batch header
  2. iteration halts when rowCount equals limit

Contributor Author:

here is the diff, compared with the latest Spark master branch https://github.com/apache/spark/blob/3c189abd73afa998e8573cbfdaf0f72445284314/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:

-  private[sql] class ArrowBatchWithSchemaIterator(
+  private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Long,
       maxEstimatedBatchSize: Long,
+      limit: Long,
       timeZoneId: String,
       context: TaskContext)
-    extends ArrowBatchIterator(
-      rowIter, schema, maxRecordsPerBatch, timeZoneId, context) {
+    extends Iterator[Array[Byte]] {
+

-    private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
     var rowCountInLastBatch: Long = 0
+    var rowCount: Long = 0

     override def next(): Array[Byte] = {
       val out = new ByteArrayOutputStream()
       val writeChannel = new WriteChannel(Channels.newChannel(out))

       rowCountInLastBatch = 0
-      var estimatedBatchSize = arrowSchemaSize
+      var estimatedBatchSize = 0
       Utils.tryWithSafeFinally {
-        // Always write the schema.
-        MessageSerializer.serialize(writeChannel, arrowSchema)

         // Always write the first row.
         while (rowIter.hasNext && (
@@ -31,15 +30,17 @@
             estimatedBatchSize < maxEstimatedBatchSize ||
             // If the size of rows are 0 or negative, unlimit it.
             maxRecordsPerBatch <= 0 ||
-            rowCountInLastBatch < maxRecordsPerBatch)) {
+            rowCountInLastBatch < maxRecordsPerBatch ||
+            rowCount < limit)) {
           val row = rowIter.next()
           arrowWriter.write(row)
           estimatedBatchSize += (row match {
             case ur: UnsafeRow => ur.getSizeInBytes
-            // Trying to estimate the size of the current row, assuming 16 bytes per value.
-            case ir: InternalRow => ir.numFields * 16
+            // Trying to estimate the size of the current row
+            case _: InternalRow => schema.defaultSize
           })
           rowCountInLastBatch += 1
+          rowCount += 1
         }
         arrowWriter.finish()
         val batch = unloader.getRecordBatch()

})
rowCountInLastBatch += 1
rowCount += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)

// Always write the Ipc options at the end.
ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

batch.close()
} {
arrowWriter.reset()
}

out.toByteArray
}
}

// for testing
def fromBatchIterator(
arrowBatchIter: Iterator[Array[Byte]],
schema: StructType,
timeZoneId: String,
context: TaskContext): Iterator[InternalRow] = {
ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context)
}
}
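
To make the iterator's contract concrete, here is a small self-contained sketch, also not part of the PR. The object name ArrowBatchIteratorSketch, the schema, and all input values are made up, and the snippet assumes it is compiled under the org.apache.spark.sql package tree (as the PR's own code is), since toBatchIterator returns the private[sql] ArrowBatchIterator.

// A hypothetical, self-contained sketch; ArrowBatchIteratorSketch and all input
// values are illustrative and not part of the PR.
package org.apache.spark.sql.execution.arrow

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object ArrowBatchIteratorSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("id", IntegerType, nullable = false)))
    val toUnsafe = UnsafeProjection.create(schema)
    // 100 input rows; copy() because UnsafeProjection reuses its row buffer.
    val rows: Iterator[InternalRow] =
      (0 until 100).iterator.map(i => toUnsafe(InternalRow(i)).copy())

    // Serialize with a limit of 10: the iterator stops emitting new batches
    // once at least 10 rows have been written.
    val batchIter = ArrowConvertersHelper.toBatchIterator(
      rows,
      schema,
      maxRecordsPerBatch = 4,
      maxEstimatedBatchSize = 4L * 1024 * 1024,
      limit = 10,
      timeZoneId = "UTC")

    // Pair every serialized batch with its row count.
    val batches = batchIter.map(bytes => bytes -> batchIter.rowCountInLastBatch).toArray
    println(s"batches=${batches.length}, totalRows=${batches.map(_._2).sum}")
  }
}

Pairing each serialized batch with rowCountInLastBatch mirrors what takeAsArrowBatches does inside sc.runJob; as noted in its docs, the summed row count is at least `limit` (and may exceed it) when the input iterator holds more rows.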