[ARROW] Arrow serialization should not introduce extra shuffle for outermost limit #4662

Closed
wants to merge 28 commits
Changes from 1 commit
Commits
28 commits
a584943  arrow take (cfmcgrady, Mar 23, 2023)
8593d85  driver slice last batch (cfmcgrady, Mar 24, 2023)
0088671  refine (cfmcgrady, Mar 29, 2023)
ed8c692  refactor (cfmcgrady, Apr 3, 2023)
4212a89  refactor and add ut (cfmcgrady, Apr 4, 2023)
6c5b1eb  add ut (cfmcgrady, Apr 4, 2023)
ee5a756  revert unnecessary changes (cfmcgrady, Apr 4, 2023)
4e7ca54  unnecessary changes (cfmcgrady, Apr 4, 2023)
885cf2c  infer row size by schema.defaultSize (cfmcgrady, Apr 4, 2023)
25e4f05  add docs (cfmcgrady, Apr 4, 2023)
03d0747  address comment (cfmcgrady, Apr 6, 2023)
2286afc  reflectively call AdaptiveSparkPlanExec.finalPhysicalPlan (cfmcgrady, Apr 6, 2023)
81886f0  address comment (cfmcgrady, Apr 6, 2023)
e3bf84c  refactor (cfmcgrady, Apr 6, 2023)
d70aee3  SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x (cfmcgrady, Apr 6, 2023)
4cef204  SparkArrowbasedOperationSuite adapt Spark-3.1.x (cfmcgrady, Apr 6, 2023)
573a262  fix (cfmcgrady, Apr 6, 2023)
c83cf3f  SparkArrowbasedOperationSuite adapt Spark-3.1.x (cfmcgrady, Apr 6, 2023)
9ffb44f  make toBatchIterator private (cfmcgrady, Apr 6, 2023)
b72bc6f  add offset support to adapt Spark-3.4.x (cfmcgrady, Apr 6, 2023)
22cc70f  add ut (cfmcgrady, Apr 6, 2023)
8280783  add `isStaticConfigKey` to adapt Spark-3.1.x (cfmcgrady, Apr 7, 2023)
6d596fc  address comment (cfmcgrady, Apr 7, 2023)
6064ab9  limit = 0 test case (cfmcgrady, Apr 7, 2023)
3700839  SparkArrowbasedOperationSuite adapt Spark-3.1.x (cfmcgrady, Apr 7, 2023)
facc13f  exclude rule OptimizeLimitZero (cfmcgrady, Apr 7, 2023)
130bcb1  finally close (cfmcgrady, Apr 7, 2023)
82c912e  close vector (cfmcgrady, Apr 8, 2023)
refactor
cfmcgrady committed Apr 4, 2023
commit ed8c6928baeda334773a3067ac08a84666f5a463
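For orientation before the diff (this note and the sketch are editorial additions, not part of the PR): with an outermost LIMIT, Spark plans the query with a CollectLimitExec node, and this PR collects Arrow batches from that node's child partition by partition instead of shuffling everything into a single partition first. A minimal way to see the plan shape, assuming a plain local Spark session:

```scala
// Editorial sketch, not PR code: inspect the physical plan of an outermost LIMIT.
import org.apache.spark.sql.SparkSession

object OutermostLimitPlanDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("outermost-limit-demo")
      .getOrCreate()

    val df = spark.range(0, 100000).toDF("id").limit(10)
    // With an outermost LIMIT the executed plan is driven by a CollectLimitExec node
    // (possibly wrapped by AdaptiveSparkPlanExec when AQE is enabled); that is the
    // node this PR intercepts to produce Arrow batches without an extra shuffle.
    println(df.queryExecution.executedPlan.treeString)

    spark.stop()
  }
}
```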
@@ -22,9 +22,10 @@ import java.util.concurrent.RejectedExecutionException
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution}
import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils}
import org.apache.spark.sql.execution.arrow.{ArrowCollectUtils, KyuubiArrowUtils}
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.types._

@@ -213,7 +214,7 @@ class ArrowBasedExecuteStatement(
}
}

def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
private def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
executeArrowBatchCollect(df).getOrElse {
SparkDatasetHelper.toArrowBatchRdd(df).collect()
}
@@ -223,17 +224,16 @@
df.queryExecution.executedPlan match {
case collectLimit @ CollectLimitExec(limit, _) =>
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
val maxRecordsPerBatch = spark.conf.getOption(
"spark.sql.execution.arrow.maxRecordsPerBatch").map(_.toInt).getOrElse(10000)
// val maxBatchSize =
// (spark.sessionState.conf.getConf(SPARK_CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
val maxBatchSize = 1024 * 1024 * 4
val batches = ArrowCollectLimitExec.takeAsArrowBatches(
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch

val batches = ArrowCollectUtils.takeAsArrowBatches(
collectLimit,
df.schema,
maxRecordsPerBatch,
maxBatchSize,
timeZoneId)

// note that the total number of rows in the returned arrow batches may be >= `limit`, so we
// slice the result down to `limit` rows
val result = ArrayBuffer[Array[Byte]]()
var i = 0
var rest = limit
@@ -244,7 +244,7 @@
// the returned ArrowRecordBatch has no more rows than the remaining limit, so it is safe to take it whole
rest -= size.toInt
} else { // size > rest
result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest)
result += KyuubiArrowUtils.slice(df.schema, timeZoneId, batch, 0, rest)
rest = 0
}
i += 1
@@ -263,4 +263,12 @@
SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
}

private lazy val maxBatchSize: Long = {
// respect spark connect config
spark.sparkContext.getConf.getOption("spark.connect.grpc.arrow.maxBatchSize")
.orElse(Option("4m"))
.map(JavaUtils.byteStringAs(_, ByteUnit.MiB))
.get
}

}
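A Spark-free sketch of the slicing loop shown in the hunk above: whole batches are kept while they fit under the remaining limit, and only the final batch is sliced. The names `sliceToLimit` and `sliceBatch` are mine for illustration; in the PR the actual slicing is done by `KyuubiArrowUtils.slice`.

```scala
// Editorial sketch of the "slice the last batch" logic; Batch mirrors the
// (serialized Arrow batch, row count) pairs returned by takeAsArrowBatches.
object SliceToLimitSketch {
  type Batch = (Array[Byte], Long)

  def sliceToLimit(batches: Seq[Batch], limit: Int)(
      sliceBatch: (Array[Byte], Int) => Array[Byte]): Seq[Array[Byte]] = {
    val result = scala.collection.mutable.ArrayBuffer[Array[Byte]]()
    var rest = limit
    val it = batches.iterator
    while (rest > 0 && it.hasNext) {
      val (bytes, rowCount) = it.next()
      if (rowCount <= rest) {
        result += bytes                    // the whole batch fits under the remaining budget
        rest -= rowCount.toInt
      } else {
        result += sliceBatch(bytes, rest)  // keep only the first `rest` rows of this batch
        rest = 0
      }
    }
    result.toSeq
  }

  def main(args: Array[String]): Unit = {
    val batches = Seq(Array[Byte](1) -> 4L, Array[Byte](2) -> 4L, Array[Byte](3) -> 4L)
    // limit = 10: the first two batches are kept whole, the third is sliced to 2 rows.
    val out = sliceToLimit(batches, 10)((bytes, _) => bytes /* pretend slice */)
    assert(out.length == 3)
  }
}
```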
@@ -17,40 +17,60 @@

package org.apache.spark.sql.execution.arrow

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.execution.{CollectLimitExec, TakeOrderedAndProjectExec}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.execution.CollectLimitExec

object ArrowCollectLimitExec extends SQLConfHelper {
object ArrowCollectUtils extends SQLConfHelper {

type Batch = (Array[Byte], Long)

/**
* Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be
* summarized in the following steps:
* 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty
* array of batches.
* 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of
* data to collect.
* 3. Use an iterative approach to collect data in batches until the specified limit is reached.
* In each iteration, it selects a subset of the partitions of the RDD to scan and tries to
* collect data from them.
* 4. For each partition subset, we use the runJob method of the Spark context to execute a
* closure that scans the partition data and converts it to Arrow batches.
* 5. Check if the collected data reaches the specified limit. If not, it selects another subset
* of partitions to scan and repeats the process until the limit is reached or all partitions
* have been scanned.
* 6. Return an array of all the collected Arrow batches.
*
* Note that:
* 1. the total row count of the returned Arrow batches may be >= `limit`, if the input df has
*    more rows than `limit`
* 2. the `takeFromEnd` logic is not implemented
*
* @return an array of `Batch`, i.e. (serialized Arrow batch, batch row count) pairs
*/
def takeAsArrowBatches(
collectLimitExec: CollectLimitExec,
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String): Array[Batch] = {
val n = collectLimitExec.limit
// TODO
val takeFromEnd = false
val schema = collectLimitExec.schema
if (n == 0) {
return new Array[Batch](0)
} else {
val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// // TODO: refactor and reuse the code from RDD's take()
// TODO: refactor and reuse the code from RDD's take()
val childRDD = collectLimitExec.child.execute()
val buf = if (takeFromEnd) new ListBuffer[Batch] else new ArrayBuffer[Batch]
val buf = new ArrayBuffer[Batch]
var bufferedRowSize = 0L
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (bufferedRowSize < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
// var numPartsToTry = conf.limitInitialNumPartitions
var numPartsToTry = 1
var numPartsToTry = limitInitialNumPartitions
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
@@ -65,63 +85,49 @@ object ArrowCollectLimitExec extends SQLConfHelper {
}
}

val parts = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val partsToScan = if (takeFromEnd) {
// Reverse partitions to scan. So, if parts was [1, 2, 3] in 200 partitions (0 to 199),
// it becomes [198, 197, 196].
parts.map(p => (totalParts - 1) - p)
} else {
parts
}
val partsToScan =
partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

val sc = collectLimitExec.session.sparkContext
val res = sc.runJob(
childRDD,
(it: Iterator[InternalRow]) => {
val batches = ArrowConvertersHelper.toBatchWithSchemaIterator(
val batches = ArrowConvertersHelper.toBatchIterator(
it,
schema,
maxRecordsPerBatch,
maxEstimatedBatchSize,
collectLimitExec.limit,
n,
timeZoneId)
batches.map(b => b -> batches.rowCountInLastBatch).toArray
},
partsToScan)

var i = 0
if (takeFromEnd) {
// while (buf.length < n && i < res.length) {
// val rows = decodeUnsafeRows(res(i)._2)
// if (n - buf.length >= res(i)._1) {
// buf.prepend(rows.toArray[InternalRow]: _*)
// } else {
// val dropUntil = res(i)._1 - (n - buf.length)
// // Same as Iterator.drop but this only takes a long.
// var j: Long = 0L
// while (j < dropUntil) { rows.next(); j += 1L}
// buf.prepend(rows.toArray[InternalRow]: _*)
// }
// i += 1
// }
} else {
while (bufferedRowSize < n && i < res.length) {
var j = 0
val batches = res(i)
while (j < batches.length && n > bufferedRowSize) {
val batch = batches(j)
val (_, batchSize) = batch
buf += batch
bufferedRowSize += batchSize
j += 1
}
i += 1
while (bufferedRowSize < n && i < res.length) {
var j = 0
val batches = res(i)
while (j < batches.length && n > bufferedRowSize) {
val batch = batches(j)
val (_, batchSize) = batch
buf += batch
bufferedRowSize += batchSize
j += 1
}
i += 1
}
partsScanned += partsToScan.size
}

buf.toArray
}
}

/**
* Spark introduced the config `spark.sql.limit.initialNumPartitions` in 3.4.0; see SPARK-40211.
*/
def limitInitialNumPartitions: Int = {
conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
.toInt
}
}
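The doc comment at the top of `takeAsArrowBatches` summarizes the executeTake-style incremental scan. Below is a self-contained editorial sketch of that loop; `collectPartition` stands in for the `sc.runJob` call, and the 1.5x headroom factor mirrors Spark's `executeTake` but is an assumption here, not code from this PR.

```scala
// Editorial sketch of the incremental partition scan described above.
object IncrementalTakeSketch {
  type Batch = (Array[Byte], Long)

  def take(
      limit: Long,
      totalParts: Int,
      initialNumPartitions: Int,
      limitScaleUpFactor: Int)(collectPartition: Int => Seq[Batch]): Seq[Batch] = {
    val buf = scala.collection.mutable.ArrayBuffer[Batch]()
    var bufferedRows = 0L
    var partsScanned = 0
    while (bufferedRows < limit && partsScanned < totalParts) {
      // Decide how many partitions to scan in this round.
      var numPartsToTry = initialNumPartitions
      if (partsScanned > 0) {
        numPartsToTry = if (bufferedRows == 0) {
          partsScanned * limitScaleUpFactor            // found nothing yet: scale up aggressively
        } else {
          val left = limit - bufferedRows              // interpolate with 1.5x headroom
          Math.min(
            Math.ceil(1.5 * left * partsScanned / bufferedRows).toInt,
            partsScanned * limitScaleUpFactor)
        }
      }
      val partsToScan = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
      // Collect batches partition by partition until the limit is covered.
      partsToScan.foreach { p =>
        if (bufferedRows < limit) {
          collectPartition(p).foreach { case batch @ (_, rows) =>
            if (bufferedRows < limit) {
              buf += batch
              bufferedRows += rows
            }
          }
        }
      }
      partsScanned += partsToScan.size
    }
    buf.toSeq
  }
}
```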
@@ -29,22 +29,22 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.util.Utils

object ArrowConvertersHelper extends Logging {

/**
* Convert the input rows into fully contained arrow batches.
* Different from [[toBatchIterator]], each output arrow batch starts with the schema.
* Different from [[org.apache.spark.sql.execution.arrow.ArrowConvertersHelper.toBatchIterator]],
* each output arrow batch is accompanied by its row count.
*/
private[sql] def toBatchWithSchemaIterator(
def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
limit: Long,
timeZoneId: String): ArrowBatchWithSchemaIterator = {
new ArrowBatchWithSchemaIterator(
timeZoneId: String): ArrowBatchIterator = {
new ArrowBatchIterator(
rowIter,
schema,
maxRecordsPerBatch,
@@ -54,7 +54,7 @@ object ArrowConvertersHelper extends Logging {
TaskContext.get)
}

private[sql] class ArrowBatchWithSchemaIterator(
private[sql] class ArrowBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
@@ -88,7 +88,6 @@ object ArrowConvertersHelper extends Logging {
false
}

private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
var rowCountInLastBatch: Long = 0
var rowCount: Long = 0

@@ -97,11 +96,8 @@
val writeChannel = new WriteChannel(Channels.newChannel(out))

rowCountInLastBatch = 0
// var estimatedBatchSize = arrowSchemaSize
var estimatedBatchSize = 0
Utils.tryWithSafeFinally {
// Always write the schema.
// MessageSerializer.serialize(writeChannel, arrowSchema)

// Always write the first row.
while (rowIter.hasNext && (
@@ -20,45 +20,20 @@ package org.apache.spark.sql.execution.arrow
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

object KyuubiArrowUtils {
val rootAllocator = new RootAllocator(Long.MaxValue)
.newChildAllocator("ReadIntTest", 0, Long.MaxValue)
// BufferAllocator allocator =
// ArrowUtils.rootAllocator.newChildAllocator("ReadIntTest", 0, Long.MAX_VALUE);
def slice(bytes: Array[Byte], start: Int, length: Int): Array[Byte] = {
val in = new ByteArrayInputStream(bytes)
val out = new ByteArrayOutputStream()

var reader: ArrowStreamReader = null
try {
reader = new ArrowStreamReader(in, rootAllocator)
// reader.getVectorSchemaRoot.getSchema
reader.loadNextBatch()
val root = reader.getVectorSchemaRoot.slice(start, length)
// val loader = new VectorLoader(root)
val writer = new ArrowStreamWriter(root, null, out)
writer.start()
writer.writeBatch()
writer.end()
writer.close()
out.toByteArray
} finally {
if (reader != null) {
reader.close()
}
in.close()
out.close()
}
}

def sliceV2(
private val rootAllocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}",
0,
Long.MaxValue)
def slice(
schema: StructType,
timeZoneId: String,
bytes: Array[Byte],
@@ -68,13 +43,6 @@ object KyuubiArrowUtils {
val out = new ByteArrayOutputStream()

try {
// reader = new ArrowStreamReader(in, rootAllocator)
// // reader.getVectorSchemaRoot.getSchema
// reader.loadNextBatch()
// println("bytes......" + bytes.length)
// println("rowCount......" + reader.getVectorSchemaRoot.getRowCount)
// val root = reader.getVectorSchemaRoot.slice(start, length)

val recordBatch = MessageSerializer.deserializeRecordBatch(
new ReadChannel(Channels.newChannel(in)),
rootAllocator)
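The remainder of `slice` is collapsed in the view above (only the `deserializeRecordBatch` call is visible). For reference, a hedged editorial sketch of how a serialized Arrow record batch can be sliced with the Arrow Java API; this is an assumed reconstruction of the approach, not the PR's exact code. In the PR the Arrow schema is presumably derived from the Spark `StructType` (e.g. via `ArrowUtils.toArrowSchema`), while the sketch takes it directly.

```scala
// Editorial sketch: slice `length` rows starting at `start` out of one serialized
// Arrow record batch, using only the Arrow Java API.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.types.pojo.Schema

object ArrowSliceSketch {
  private val allocator = new RootAllocator(Long.MaxValue)

  def slice(arrowSchema: Schema, bytes: Array[Byte], start: Int, length: Int): Array[Byte] = {
    val in = new ByteArrayInputStream(bytes)
    val out = new ByteArrayOutputStream()
    val root = VectorSchemaRoot.create(arrowSchema, allocator)
    try {
      // 1. deserialize the IPC message back into an ArrowRecordBatch
      val recordBatch = MessageSerializer.deserializeRecordBatch(
        new ReadChannel(Channels.newChannel(in)), allocator)
      // 2. load it into a VectorSchemaRoot and slice the requested row range
      new VectorLoader(root).load(recordBatch)
      recordBatch.close()
      val sliced = root.slice(start, length)
      // 3. unload the sliced root and serialize it back to IPC bytes
      val slicedBatch = new VectorUnloader(sliced).getRecordBatch
      MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), slicedBatch)
      slicedBatch.close()
      sliced.close()
      out.toByteArray
    } finally {
      root.close()
      in.close()
      out.close()
    }
  }
}
```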
@@ -142,15 +142,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp

test("aa") {

val returnSize = Seq(
7,
10,
13,
20,
29)

withJdbcStatement() { statement =>
loadPartitionedTable()
returnSize.foreach { size =>
statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$size")
val result = statement.executeQuery("select * from t_1")
for (i <- 0 until size) {
assert(result.next())
}
assert(!result.next())
}
}

withJdbcStatement() { statement =>
loadPartitionedTable()
val n = 17
statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$n")
val result = statement.executeQuery("select * from t_1")
for (i <- 0 until n) {
assert(result.next())
returnSize.foreach { size =>
val result = statement.executeQuery(s"select * from t_1 limit $size")
for (i <- 0 until size) {
assert(result.next())
}
assert(!result.next())
}
assert(!result.next())
}
}
