[ARROW] Arrow serialization should not introduce extra shuffle for outermost limit #4662
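For context, the plan shape this change targets can be reproduced with a small local-mode snippet. This is an illustration only (the session setup, object name, and view name are assumptions, not part of this PR): an outermost LIMIT compiles to a CollectLimitExec, and collecting it through its RDD adds an extra shuffle, which this PR avoids by taking rows on the driver.

// Standalone illustration (assumed local setup, not part of this PR) of the plan
// shape this change targets: an outermost LIMIT compiles to CollectLimitExec.
import org.apache.spark.sql.SparkSession

object OutermostLimitExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("outermost-limit").getOrCreate()
    try {
      spark.range(0, 1000, 1, numPartitions = 100).createOrReplaceTempView("t")
      val df = spark.sql("select * from t limit 13")
      // The top physical node is CollectLimitExec (possibly wrapped in
      // AdaptiveSparkPlanExec when AQE kicks in). Before this PR, Kyuubi's Arrow
      // collection executed it as an RDD, which adds a shuffle; after, the rows
      // are taken incrementally on the driver.
      df.explain()
    } finally {
      spark.stop()
    }
  }
}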

Closed · wants to merge 28 commits

Changes from 1 commit

Commits (28)
a584943
arrow take
cfmcgrady Mar 23, 2023
8593d85
driver slice last batch
cfmcgrady Mar 24, 2023
0088671
refine
cfmcgrady Mar 29, 2023
ed8c692
refactor
cfmcgrady Apr 3, 2023
4212a89
refactor and add ut
cfmcgrady Apr 4, 2023
6c5b1eb
add ut
cfmcgrady Apr 4, 2023
ee5a756
revert unnecessary changes
cfmcgrady Apr 4, 2023
4e7ca54
unnecessary changes
cfmcgrady Apr 4, 2023
885cf2c
infer row size by schema.defaultSize
cfmcgrady Apr 4, 2023
25e4f05
add docs
cfmcgrady Apr 4, 2023
03d0747
address comment
cfmcgrady Apr 6, 2023
2286afc
reflectively call AdaptiveSparkPlanExec.finalPhysicalPlan
cfmcgrady Apr 6, 2023
81886f0
address comment
cfmcgrady Apr 6, 2023
e3bf84c
refactor
cfmcgrady Apr 6, 2023
d70aee3
SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x
cfmcgrady Apr 6, 2023
4cef204
SparkArrowbasedOperationSuite adapt Spark-3.1.x
cfmcgrady Apr 6, 2023
573a262
fix
cfmcgrady Apr 6, 2023
c83cf3f
SparkArrowbasedOperationSuite adapt Spark-3.1.x
cfmcgrady Apr 6, 2023
9ffb44f
make toBatchIterator private
cfmcgrady Apr 6, 2023
b72bc6f
add offset support to adapt Spark-3.4.x
cfmcgrady Apr 6, 2023
22cc70f
add ut
cfmcgrady Apr 6, 2023
8280783
add `isStaticConfigKey` to adapt Spark-3.1.x
cfmcgrady Apr 7, 2023
6d596fc
address comment
cfmcgrady Apr 7, 2023
6064ab9
limit = 0 test case
cfmcgrady Apr 7, 2023
3700839
SparkArrowbasedOperationSuite adapt Spark-3.1.x
cfmcgrady Apr 7, 2023
facc13f
exclude rule OptimizeLimitZero
cfmcgrady Apr 7, 2023
130bcb1
finally close
cfmcgrady Apr 7, 2023
82c912e
close vector
cfmcgrady Apr 8, 2023
add ut
cfmcgrady committed Apr 6, 2023
commit 22cc70fbae969d8a1d1d5afaf7d3a4d7301cc9e0
7 changes: 7 additions & 0 deletions externals/kyuubi-spark-sql-engine/pom.xml
@@ -65,6 +65,13 @@
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-repl_${scala.binary.version}</artifactId>
@@ -20,6 +20,7 @@ package org.apache.spark.sql.kyuubi
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -33,7 +34,7 @@ import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
import org.apache.kyuubi.engine.spark.schema.RowSet
import org.apache.kyuubi.reflection.DynMethods

object SparkDatasetHelper {
object SparkDatasetHelper extends Logging {

def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
executeArrowBatchCollect(df.queryExecution.executedPlan)
@@ -43,8 +44,13 @@ object SparkDatasetHelper {
case adaptiveSparkPlan: AdaptiveSparkPlanExec =>
executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan))
// TODO: avoid extra shuffle if `offset` > 0
case collectLimit: CollectLimitExec if offset(collectLimit) <= 0 =>
case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
toArrowBatchRdd(collectLimit).collect()
case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
doCollectLimit(collectLimit)
case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
executeArrowBatchCollect(collectLimit.child)
case plan: SparkPlan =>
toArrowBatchRdd(plan).collect()
}
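Because the hunk above interleaves the removed and added lines, here is the resulting dispatch in SparkDatasetHelper.executeArrowBatchCollect pieced together as a sketch. The stubbed helpers (finalPhysicalPlan, offset, doCollectLimit, toArrowBatchRdd) stand in for implementations defined elsewhere in the object; only the match structure is the point.

// Sketch of the dispatch after this commit; helpers are stubbed with ???
// and exist for real elsewhere in SparkDatasetHelper.
package org.apache.spark.sql.kyuubi

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

object ArrowCollectSketch extends Logging {
  // Stubs standing in for the real helpers in SparkDatasetHelper.
  private def finalPhysicalPlan(plan: AdaptiveSparkPlanExec): SparkPlan = ???
  private def offset(collectLimit: CollectLimitExec): Int = ???
  private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = ???
  private def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = ???

  def executeArrowBatchCollect(plan: SparkPlan): Array[Array[Byte]] = plan match {
    case adaptive: AdaptiveSparkPlanExec =>
      // Unwrap AQE and recurse into the final physical plan.
      executeArrowBatchCollect(finalPhysicalPlan(adaptive))
    case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
      // offset > 0 still falls back to the RDD path and its extra shuffle.
      logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
      toArrowBatchRdd(collectLimit).collect()
    case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
      // Non-negative limit: take rows on the driver without an extra shuffle.
      doCollectLimit(collectLimit)
    case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
      // Negative limit means "no limit": delegate to the child plan.
      executeArrowBatchCollect(collectLimit.child)
    case other: SparkPlan =>
      toArrowBatchRdd(other).collect()
  }
}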
@@ -21,15 +21,19 @@ import java.sql.Statement

import org.apache.spark.KyuubiSparkContextHelper
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution}
import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.util.QueryExecutionListener

import org.apache.kyuubi.KyuubiException
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
@@ -150,57 +154,111 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
10, // equal to one partition
13, // between one and two partitions, run two jobs
20, // equal to two partitions
29) // between two and three partitions

// aqe
// outermost AdaptiveSparkPlanExec
spark.range(1000)
.repartitionByRange(100, col("id"))
.createOrReplaceTempView("t_1")
spark.sql("select * from t_1")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_1 limit $size")
val headPlan = df.queryExecution.executedPlan.collectLeaves().head
if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
val finalPhysicalPlan =
SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
}
29, // between two and three partitions
1000, // all partitions
1001) // more than total row count
// -1) // all

withSQLConf(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits",
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key ->
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits") {
// aqe
// outermost AdaptiveSparkPlanExec
spark.range(1000)
.repartitionByRange(100, col("id"))
.createOrReplaceTempView("t_1")
spark.sql("select * from t_1")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_1 limit $size")
val headPlan = df.queryExecution.executedPlan.collectLeaves().head
if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
val finalPhysicalPlan =
SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
}

val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan)
val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution
.executedPlan)

val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
if (size > 1000) {
assert(rows.size == 1000)
} else {
assert(rows.size == size)
}
}

val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
assert(rows.size == size)
// outermost CollectLimitExec
spark.range(0, 1000, 1, numPartitions = 100)
.createOrReplaceTempView("t_2")
spark.sql("select * from t_2")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_2 limit $size")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[CollectLimitExec])
val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution
.executedPlan)
val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
if (size > 1000) {
assert(rows.size == 1000)
} else {
assert(rows.size == size)
}
}
}
}

// outermost CollectLimitExec
spark.range(0, 1000, 1, numPartitions = 100)
.createOrReplaceTempView("t_2")
spark.sql("select * from t_2")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_2 limit $size")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[CollectLimitExec])
val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan)
val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
assert(rows.size == size)
test("aqe should work properly") {

val s = spark
import s.implicits._

spark.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
.createOrReplaceTempView("testData")
spark.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
TestData2(3, 2) :: Nil,
2).toDF()
.createOrReplaceTempView("testData2")

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "5",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|SELECT * FROM(
| SELECT * FROM testData join testData2 ON key = a where value = '1'
|) LIMIT 1
|""".stripMargin)
val smj = plan.collect { case smj: SortMergeJoinExec => smj }
val bhj = adaptivePlan.collect { case bhj: BroadcastHashJoinExec => bhj }
assert(smj.size == 1)
assert(bhj.size == 1)
}
}

@@ -315,4 +373,53 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
.map(_.asInstanceOf[SparkSessionImpl].spark)
.foreach(op(_))
}

private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
val dfAdaptive = spark.sql(query)
val planBefore = dfAdaptive.queryExecution.executedPlan
val result = dfAdaptive.collect()
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val df = spark.sql(query)
QueryTest.checkAnswer(df, df.collect().toSeq)
}
val planAfter = dfAdaptive.queryExecution.executedPlan
val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
val exchanges = adaptivePlan.collect {
case e: Exchange => e
}
assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
(dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
}

/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
* configurations.
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val conf = SQLConf.get
val (keys, values) = pairs.unzip
val currentValues = keys.map { key =>
if (conf.contains(key)) {
Some(conf.getConfString(key))
} else {
None
}
}
(keys, values).zipped.foreach { (k, v) =>
if (SQLConf.isStaticConfigKey(k)) {
throw new KyuubiException(s"Cannot modify the value of a static config: $k")
}
conf.setConfString(k, v)
}
try f
finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.setConfString(key, value)
case (key, None) => conf.unsetConf(key)
}
}
}
}

case class TestData(key: Int, value: String)
case class TestData2(a: Int, b: Int)