diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index a308c4748f4..a1f303a2623 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -25,9 +25,8 @@ import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.execution.{CollectLimitExec, HiveResult, LocalTableScanExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{CollectLimitExec, HiveResult, LocalTableScanExec, SparkPlan, SQLExecution, TakeOrderedAndProjectExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -295,8 +294,10 @@ object SparkDatasetHelper extends Logging {
     SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics.values.toSeq)
   }
 
-  private[kyuubi] def logicalPlanLimit(plan: LogicalPlan): Option[Long] = plan match {
-    case globalLimit: GlobalLimit => globalLimit.maxRows
+  private[kyuubi] def planLimit(plan: SparkPlan): Option[Int] = plan match {
+    case tp: TakeOrderedAndProjectExec => Option(tp.limit)
+    case c: CollectLimitExec => Option(c.limit)
+    case ap: AdaptiveSparkPlanExec => planLimit(ap.inputPlan)
     case _ => None
   }
 
@@ -304,7 +305,7 @@ object SparkDatasetHelper extends Logging {
     if (isCommandExec(result.queryExecution.executedPlan.nodeName)) {
       return false
     }
-    val finalLimit = logicalPlanLimit(result.queryExecution.logical) match {
+    val finalLimit = planLimit(result.queryExecution.sparkPlan) match {
       case Some(limit) if resultMaxRows > 0 => math.min(limit, resultMaxRows)
       case Some(limit) => limit
       case None => resultMaxRows
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
index 036b4dfd54f..8e51484b176 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
@@ -24,7 +24,7 @@ import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
 class SparkDatasetHelperSuite extends WithSparkSQLEngine {
   override def withKyuubiConf: Map[String, String] = Map.empty
 
-  test("get limit from logical plan") {
+  test("get limit from spark plan") {
     Seq(true, false).foreach { aqe =>
       val topKThreshold = 3
       spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, aqe)
@@ -32,13 +32,14 @@ class SparkDatasetHelperSuite extends WithSparkSQLEngine {
       spark.sql("CREATE OR REPLACE TEMPORARY VIEW tv AS" +
         " SELECT * FROM VALUES(1),(2),(3),(4) AS t(id)")
 
-      val topKStatement = s"SELECT * FROM tv ORDER BY id LIMIT ${topKThreshold - 1}"
-      assert(SparkDatasetHelper.logicalPlanLimit(
-        spark.sql(topKStatement).queryExecution.logical) === Option(topKThreshold - 1))
+      val topKStatement = s"SELECT * FROM(SELECT * FROM tv ORDER BY id LIMIT ${topKThreshold - 1})"
+      assert(SparkDatasetHelper.planLimit(
+        spark.sql(topKStatement).queryExecution.sparkPlan) === Option(topKThreshold - 1))
 
-      val collectLimitStatement = s"SELECT * FROM tv ORDER BY id LIMIT $topKThreshold"
-      assert(SparkDatasetHelper.logicalPlanLimit(
-        spark.sql(collectLimitStatement).queryExecution.logical) === Option(topKThreshold))
+      val collectLimitStatement =
+        s"SELECT * FROM (SELECT * FROM tv ORDER BY id LIMIT $topKThreshold)"
+      assert(SparkDatasetHelper.planLimit(
+        spark.sql(collectLimitStatement).queryExecution.sparkPlan) === Option(topKThreshold))
     }
   }
 }
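For context: the patch reads the limit from the physical plan instead of the logical plan, matching a TakeOrderedAndProjectExec or CollectLimitExec root and unwrapping an AdaptiveSparkPlanExec root via inputPlan. The standalone sketch below is not part of the patch; the PlanLimitSketch object and the local SparkSession setup are illustrative assumptions, and it simply mirrors the patched planLimit so the extraction can be observed in a plain Spark 3.x application.

// Standalone sketch (assumption: local Spark 3.x, same query shapes as the suite above).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

object PlanLimitSketch {

  // Same shape as the patched SparkDatasetHelper.planLimit: inspect the root physical node,
  // unwrapping an AQE root (when present) so the underlying limit node stays visible.
  def planLimit(plan: SparkPlan): Option[Int] = plan match {
    case tp: TakeOrderedAndProjectExec => Option(tp.limit)
    case c: CollectLimitExec => Option(c.limit)
    case ap: AdaptiveSparkPlanExec => planLimit(ap.inputPlan)
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("plan-limit-sketch")
      .master("local[1]")
      .getOrCreate()

    // Same data and query shapes as the suite above.
    spark.sql("CREATE OR REPLACE TEMPORARY VIEW tv AS SELECT * FROM VALUES(1),(2),(3),(4) AS t(id)")

    val topK = spark.sql("SELECT * FROM (SELECT * FROM tv ORDER BY id LIMIT 2)")
    println(topK.queryExecution.sparkPlan.getClass.getSimpleName) // root physical node chosen by the planner
    println(planLimit(topK.queryExecution.sparkPlan))             // expected Some(2), per the suite's assertion

    val limited = spark.sql("SELECT * FROM (SELECT * FROM tv ORDER BY id LIMIT 3)")
    println(planLimit(limited.queryExecution.sparkPlan))          // expected Some(3)

    spark.stop()
  }
}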