
Commit

add ut
cfmcgrady committed Apr 6, 2023
1 parent b72bc6f commit 22cc70f
Showing 3 changed files with 171 additions and 51 deletions.
7 changes: 7 additions & 0 deletions externals/kyuubi-spark-sql-engine/pom.xml
@@ -65,6 +65,13 @@
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-repl_${scala.binary.version}</artifactId>
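The spark-sql test-jar added above puts Spark SQL's test utilities on the engine's test classpath; the suite below relies on it for QueryTest.checkAnswer. A minimal sketch of the kind of check it enables (the query and the `spark` session here are illustrative assumptions, not part of this commit):

import org.apache.spark.sql.{QueryTest, Row}

// QueryTest ships in the spark-sql test-jar; checkAnswer compares a
// DataFrame's result with the expected rows and fails with a diff otherwise.
val df = spark.sql("SELECT 1 AS id")
QueryTest.checkAnswer(df, Seq(Row(1)))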
SparkDatasetHelper.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.kyuubi
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -33,7 +34,7 @@ import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
import org.apache.kyuubi.engine.spark.schema.RowSet
import org.apache.kyuubi.reflection.DynMethods

object SparkDatasetHelper {
object SparkDatasetHelper extends Logging {

def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
executeArrowBatchCollect(df.queryExecution.executedPlan)
@@ -43,8 +44,13 @@ object SparkDatasetHelper {
case adaptiveSparkPlan: AdaptiveSparkPlanExec =>
executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan))
// TODO: avoid extra shuffle if `offset` > 0
case collectLimit: CollectLimitExec if offset(collectLimit) <= 0 =>
case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
toArrowBatchRdd(collectLimit).collect()
case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
doCollectLimit(collectLimit)
case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
executeArrowBatchCollect(collectLimit.child)
case plan: SparkPlan =>
toArrowBatchRdd(plan).collect()
}
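Because the hunk above interleaves the removed guard (`offset(collectLimit) <= 0`) with its replacements, here is a consolidated sketch of the dispatch after this change (assembled from the hunk; offset, doCollectLimit, finalPhysicalPlan and toArrowBatchRdd are existing helpers of SparkDatasetHelper not shown in this diff, and the method signature is assumed from the call site in executeCollect):

def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = {
  case adaptiveSparkPlan: AdaptiveSparkPlanExec =>
    // resolve the final plan under AQE before collecting
    executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan))
  // TODO: avoid extra shuffle if `offset` > 0
  case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
    // falling back to a plain Arrow collect introduces one extra shuffle
    logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
    toArrowBatchRdd(collectLimit).collect()
  case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
    // non-negative limit: take the limit-aware collect path
    doCollectLimit(collectLimit)
  case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
    // a negative limit means "no limit", so collect the child plan directly
    executeArrowBatchCollect(collectLimit.child)
  case plan: SparkPlan =>
    toArrowBatchRdd(plan).collect()
}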
SparkArrowbasedOperationSuite.scala
@@ -21,15 +21,19 @@ import java.sql.Statement

import org.apache.spark.KyuubiSparkContextHelper
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution}
import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.util.QueryExecutionListener

import org.apache.kyuubi.KyuubiException
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
@@ -150,57 +154,111 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
10, // equal to one partition
13, // between one and two partitions, run two jobs
20, // equal to two partitions
29) // between two and three partitions

// aqe
// outermost AdaptiveSparkPlanExec
spark.range(1000)
.repartitionByRange(100, col("id"))
.createOrReplaceTempView("t_1")
spark.sql("select * from t_1")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_1 limit $size")
val headPlan = df.queryExecution.executedPlan.collectLeaves().head
if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
val finalPhysicalPlan =
SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
}
29, // between two and three partitions
1000, // all partitions
1001) // more than total row count
// -1) // all

withSQLConf(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits",
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key ->
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits") {
// aqe
// outermost AdaptiveSparkPlanExec
spark.range(1000)
.repartitionByRange(100, col("id"))
.createOrReplaceTempView("t_1")
spark.sql("select * from t_1")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_1 limit $size")
val headPlan = df.queryExecution.executedPlan.collectLeaves().head
if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
val finalPhysicalPlan =
SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
}

val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan)
val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution
.executedPlan)

val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
if (size > 1000) {
assert(rows.size == 1000)
} else {
assert(rows.size == size)
}
}

val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
assert(rows.size == size)
// outermost CollectLimitExec
spark.range(0, 1000, 1, numPartitions = 100)
.createOrReplaceTempView("t_2")
spark.sql("select * from t_2")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_2 limit $size")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[CollectLimitExec])
val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution
.executedPlan)
val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
if (size > 1000) {
assert(rows.size == 1000)
} else {
assert(rows.size == size)
}
}
}
}
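The extra 1000 and 1001 entries exercise limits at and above the view's total row count. EliminateLimits is excluded from both the regular and the AQE optimizer because that rule can drop a LIMIT it proves to be at least as large as its child's maximum row count, in which case the executed plan would no longer end in a CollectLimitExec and the branch under test would never run. An illustrative sketch of the behavior being guarded against (not part of this commit; the exact plan shape depends on the Spark version):

// With EliminateLimits active (the default), a LIMIT that is provably >= the
// child's max row count may be removed during optimization, so no
// GlobalLimit/LocalLimit (and hence no CollectLimitExec) survives.
val qe = spark.sql("select * from t_2 limit 1001").queryExecution
println(qe.optimizedPlan.treeString) // inspect: the Limit node may already be gone
println(qe.executedPlan.treeString)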

// outermost CollectLimitExec
spark.range(0, 1000, 1, numPartitions = 100)
.createOrReplaceTempView("t_2")
spark.sql("select * from t_2")
.foreachPartition { p: Iterator[Row] =>
assert(p.length == 10)
()
}
returnSize.foreach { size =>
val df = spark.sql(s"select * from t_2 limit $size")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[CollectLimitExec])
val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan)
val rows = KyuubiArrowConverters.fromBatchIterator(
arrowBinary.iterator,
df.schema,
"",
KyuubiSparkContextHelper.dummyTaskContext())
assert(rows.size == size)
test("aqe should work properly") {

val s = spark
import s.implicits._

spark.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
.createOrReplaceTempView("testData")
spark.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
TestData2(3, 2) :: Nil,
2).toDF()
.createOrReplaceTempView("testData2")

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "5",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|SELECT * FROM(
| SELECT * FROM testData join testData2 ON key = a where value = '1'
|) LIMIT 1
|""".stripMargin)
val smj = plan.collect { case smj: SortMergeJoinExec => smj }
val bhj = adaptivePlan.collect { case bhj: BroadcastHashJoinExec => bhj }
assert(smj.size == 1)
assert(bhj.size == 1)
}
}

@@ -315,4 +373,53 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
.map(_.asInstanceOf[SparkSessionImpl].spark)
.foreach(op(_))
}

private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
val dfAdaptive = spark.sql(query)
val planBefore = dfAdaptive.queryExecution.executedPlan
val result = dfAdaptive.collect()
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val df = spark.sql(query)
QueryTest.checkAnswer(df, df.collect().toSeq)
}
val planAfter = dfAdaptive.queryExecution.executedPlan
val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
val exchanges = adaptivePlan.collect {
case e: Exchange => e
}
assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
(dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
}

/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
* configurations.
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val conf = SQLConf.get
val (keys, values) = pairs.unzip
val currentValues = keys.map { key =>
if (conf.contains(key)) {
Some(conf.getConfString(key))
} else {
None
}
}
(keys, values).zipped.foreach { (k, v) =>
if (SQLConf.isStaticConfigKey(k)) {
throw new KyuubiException(s"Cannot modify the value of a static config: $k")
}
conf.setConfString(k, v)
}
try f
finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.setConfString(key, value)
case (key, None) => conf.unsetConf(key)
}
}
}
}

case class TestData(key: Int, value: String)
case class TestData2(a: Int, b: Int)
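The withSQLConf helper added to the suite mirrors the one in Spark's own SQL test helpers: it snapshots the current values of the given keys, applies the overrides, runs the block, and restores (or unsets) the originals in the finally block, refusing to touch static configs. A small usage sketch (the key and value are arbitrary examples, not part of this commit):

// The override is visible only inside the block; the previous value (or its
// absence) is restored afterwards, even if the body throws.
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
  assert(SQLConf.get.getConfString(SQLConf.SHUFFLE_PARTITIONS.key) == "5")
}

TestData and TestData2 mirror the small join fixtures in Spark's own SQLTestData; they back the testData and testData2 temp views used by the "aqe should work properly" test above.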
