
[SPARK-16355] [SPARK-16354] [SQL] Fix Bugs When LIMIT/TABLESAMPLE is Non-foldable, Zero or Negative #14034


Closed · wants to merge 16 commits
@@ -46,6 +46,21 @@ trait CheckAnalysis extends PredicateHelper {
}).length > 1
}

private def checkLimitClause(limitExpr: Expression): Unit = {
limitExpr match {
case e if !e.foldable => failAnalysis(
"The limit expression must evaluate to a constant value, but got " +
limitExpr.sql)
case e if e.dataType != IntegerType => failAnalysis(
s"The limit expression must be integer type, but got " +
e.dataType.simpleString)
case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis(
"The limit expression must be equal to or greater than 0, but got " +
e.eval().asInstanceOf[Int])
case e => // OK
}
}
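// A sketch of the three failure modes this method introduces, mirroring the
// tests added later in this PR (assumes a SparkSession `spark` with a temp
// view `testData`; TABLESAMPLE (n ROWS) is parsed into the same Limit node,
// so it goes through the same check):
//
//   spark.sql("SELECT * FROM testData LIMIT key > 3")   // not foldable
//   spark.sql("SELECT * FROM testData LIMIT true")      // not integer type
//   spark.sql("SELECT * FROM testData LIMIT -1")        // negative
//   spark.sql("SELECT * FROM testData TABLESAMPLE (-1 ROWS)")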

def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
@@ -251,6 +266,10 @@ trait CheckAnalysis extends PredicateHelper {
s"but one table has '${firstError.output.length}' columns and another table has " +
s"'${s.children.head.output.length}' columns")

case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr)

case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr)
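// Note: an ordinary LIMIT produces both nodes, since Limit(n, child) builds
// GlobalLimit(n, LocalLimit(n, child)), so both cases above route the same
// expression through checkLimitClause.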

case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
p match {
case _: Filter | _: Aggregate | _: Project => // Ok
@@ -660,7 +660,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode
}
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
- val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ val sizeInBytes = if (limit == 0) {
+   // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
+   // (product of children).
+   1
+ } else {
+   (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ }
child.statistics.copy(sizeInBytes = sizeInBytes)
}
}
@@ -675,7 +681,13 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode
}
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
- val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ val sizeInBytes = if (limit == 0) {
+   // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
+   // (product of children).
+   1
+ } else {
+   (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ }
child.statistics.copy(sizeInBytes = sizeInBytes)
}
}
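Why clamp to 1 rather than 0: a logical plan's default statistics estimate a
node's sizeInBytes as the product of its children's estimates, so a zero-sized
limit child would collapse the estimate of an enclosing join to zero and could
incorrectly qualify it for a broadcast join. A minimal sketch of that default
(illustrative, not the exact Spark source):

def defaultSizeInBytes(children: Seq[BigInt]): BigInt = children.product

defaultSizeInBytes(Seq(BigInt(96), BigInt(0))) // 0  -> join looks free
defaultSizeInBytes(Seq(BigInt(96), BigInt(1))) // 96 -> sane estimate

The same change is applied to both GlobalLimit and LocalLimit, and the
StatisticsSuite tests below pin down the resulting numbers.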
@@ -352,6 +352,12 @@ class AnalysisErrorSuite extends AnalysisTest {
"Generators are not supported outside the SELECT clause, but got: Sort" :: Nil
)

errorTest(
"num_rows in limit clause must be equal to or greater than 0",
listRelation.limit(-1),
"The limit expression must be equal to or greater than 0, but got -1" :: Nil
)

errorTest(
"more than one generators in SELECT",
listRelation.select(Explode('list), Explode('list)),
37 changes: 35 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -660,18 +660,51 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {

test("limit") {
checkAnswer(
sql("SELECT * FROM testData LIMIT 10"),
sql("SELECT * FROM testData LIMIT 9 + 1"),
testData.take(10).toSeq)

checkAnswer(
sql("SELECT * FROM arrayData LIMIT 1"),
sql("SELECT * FROM arrayData LIMIT CAST(1 AS Integer)"),
arrayData.collect().take(1).map(Row.fromTuple).toSeq)

checkAnswer(
sql("SELECT * FROM mapData LIMIT 1"),
mapData.collect().take(1).map(Row.fromTuple).toSeq)
}
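// Review note: the literals above were rewritten as foldable expressions
// (9 + 1, CAST(1 AS Integer)) to confirm that the new check accepts any
// constant-foldable integer expression, not just integer literals.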

test("non-foldable expressions in LIMIT") {
val e = intercept[AnalysisException] {
sql("SELECT * FROM testData LIMIT key > 3")
}.getMessage
assert(e.contains("The limit expression must evaluate to a constant value, " +
"but got (testdata.`key` > 3)"))
}

[Review comment, Contributor]
What will happen if the type is wrong? e.g. LIMIT true?

[Reply, Member (author)]
A good question! :) Right now the exception we issue is not good:

java.lang.ClassCastException: java.lang.Boolean cannot be cast to java.lang.Integer
    at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:101)

Let me fix it and throw a more reasonable exception:

number_rows in limit clause cannot be cast to integer:true;

test("Expressions in limit clause are not integer") {
var e = intercept[AnalysisException] {
sql("SELECT * FROM testData LIMIT true")
}.getMessage
assert(e.contains("The limit expression must be integer type, but got boolean"))

e = intercept[AnalysisException] {
sql("SELECT * FROM testData LIMIT 'a'")
}.getMessage
assert(e.contains("The limit expression must be integer type, but got string"))
}

test("negative in LIMIT or TABLESAMPLE") {
val expected = "The limit expression must be equal to or greater than 0, but got -1"
var e = intercept[AnalysisException] {
sql("SELECT * FROM testData TABLESAMPLE (-1 rows)")
}.getMessage
assert(e.contains(expected))

e = intercept[AnalysisException] {
sql("SELECT * FROM testData LIMIT -1")
}.getMessage
assert(e.contains(expected))
}

test("CTE feature") {
checkAnswer(
sql("with q1 as (select * from testData limit 10) select * from q1"),
44 changes: 44 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
@@ -17,10 +17,12 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._

class StatisticsSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
@@ -31,4 +33,46 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
spark.sessionState.conf.autoBroadcastJoinThreshold)
}

test("estimates the size of limit") {
withTempTable("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
val df = sql(s"""SELECT * FROM test limit $limit""")

val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
g.statistics.sizeInBytes
}
assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesGlobalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")

val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
l.statistics.sizeInBytes
}
assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizesLocalLimit.head === BigInt(expected),
s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
}
}
}
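// Where the expected sizes come from, assuming the default column sizes at
// the time of this PR (StringType = 20 bytes, IntegerType = 4 bytes):
//   row size: 20 + 4 = 24 bytes   (columns k: String, v: Int)
//   limit 1 -> 1 * 24 = 24; limit 2 -> 2 * 24 = 48
//   limit 0 -> 1 (clamped, see the GlobalLimit/LocalLimit change above)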

test("estimates the size of a limit 0 on outer join") {
withTempTable("test") {
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
.createOrReplaceTempView("test")
val df1 = spark.table("test")
val df2 = spark.table("test").limit(0)
val df = df1.join(df2, Seq("k"), "left")

val sizes = df.queryExecution.analyzed.collect { case g: Join =>
g.statistics.sizeInBytes
}

assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
assert(sizes.head === BigInt(96),
s"expected exact size 96 for table 'test', got: ${sizes.head}")
}
}
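// The expected 96 is the product of the children's estimates: the un-limited
// side is 4 rows * 24 bytes = 96 and the limit-0 side is clamped to 1, so
// 96 * 1 = 96. With a 0 estimate the product would collapse to 0 and the
// join could wrongly qualify for a broadcast.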

}