Commit 6c5b1eb: add ut
cfmcgrady committed Apr 4, 2023
1 parent 4212a89 · commit 6c5b1eb
Showing 1 changed file with 55 additions and 3 deletions.
@@ -20,7 +20,8 @@ package org.apache.kyuubi.engine.spark.operation
import java.sql.Statement

import org.apache.spark.KyuubiSparkContextHelper
-import org.apache.spark.sql.Row
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
@@ -143,8 +144,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests
    assert(metrics("numOutputRows").value === 1)
  }

-  test("aa") {
-
+  test("SparkDatasetHelper.executeArrowBatchCollect should return expected row count") {
    val returnSize = Seq(
      7, // less than one partition
      10, // equal to one partition
@@ -202,6 +202,25 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests
    }
  }

  test("arrow serialization should not introduce extra shuffle for outermost limit") {
    var numStages = 0
    val listener = new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        numStages = jobStart.stageInfos.length
      }
    }
    withJdbcStatement() { statement =>
      withSparkListener(listener) {
        withPartitionedTable("t_3") {
          statement.executeQuery("select * from t_3 limit 1000")
        }
        KyuubiSparkContextHelper.waitListenerBus(spark)
      }
    }
    // Should be only one stage since there is no shuffle.
    assert(numStages == 1)
  }
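
The test above goes through the full JDBC and engine stack. The same stage-counting technique can be sketched against a bare SparkSession; the snippet below is illustrative only — the object name, local master, and sleep-based flush are assumptions of the sketch, not part of this commit (the suite instead flushes events via KyuubiSparkContextHelper.waitListenerBus, a helper placed in the org.apache.spark package, presumably to reach Spark's non-public listener bus).

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.SparkSession

object NoShuffleLimitCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("no-shuffle-limit-check")
      .getOrCreate()
    var numStages = 0
    val listener = new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        // A shuffle boundary splits a job into multiple stages, so a
        // shuffle-free plan submits exactly one stage per job.
        numStages = jobStart.stageInfos.length
      }
    }
    spark.sparkContext.addSparkListener(listener)
    try {
      // Outermost limit over a 100-partition range: collecting it should
      // not repartition the data.
      spark.range(0, 1000, 1, numPartitions = 100).limit(10).collect()
      // The listener bus is asynchronous and its flush API is not public,
      // so this sketch settles for a crude sleep before reading the counter.
      Thread.sleep(2000)
      assert(numStages == 1, s"expected a single stage, saw $numStages")
    } finally {
      spark.sparkContext.removeSparkListener(listener)
      spark.stop()
    }
  }
}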

  private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
    val query =
      s"""
@@ -241,4 +260,37 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests
      .allSessions()
      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
  }

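  // Loan-pattern helper: registers `listener` on every active session's
  // SparkContext, runs `body`, then always removes the listener again.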
  private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
    withAllSessions(s => s.sparkContext.addSparkListener(listener))
    try {
      body
    } finally {
      withAllSessions(s => s.sparkContext.removeSparkListener(listener))
    }
  }

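  // Registers a 1000-row temp view spread across 100 partitions (10 rows
  // per partition) under `viewName`, runs `body`, then drops the view.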
  private def withPartitionedTable[T](viewName: String)(body: => T): T = {
    withAllSessions { spark =>
      spark.range(0, 1000, 1, numPartitions = 100)
        .createOrReplaceTempView(viewName)
    }
    try {
      body
    } finally {
      withAllSessions { spark =>
        spark.sql(s"DROP VIEW IF EXISTS $viewName")
      }
    }
  }

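  // Applies `op` to the underlying SparkSession of every session tracked by
  // the engine's session manager; the shared engine may serve several
  // connections, each with its own session.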
  private def withAllSessions(op: SparkSession => Unit): Unit = {
    SparkSQLEngine.currentEngine.get
      .backendService
      .sessionManager
      .allSessions()
      .map(_.asInstanceOf[SparkSessionImpl].spark)
      .foreach(op(_))
  }
}
