Skip to content

[SPARK-13306][SQL] Addendum to uncorrelated scalar subquery #11285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,12 @@ class Analyzer(
}
substituted.getOrElse(u)
case other =>
// This can't be done in ResolveSubquery because that does not know the CTE.
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression =>
e.withNewPlan(substituteCTE(e.query, cteRelations))
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,3 @@ case class Literal protected (value: Any, dataType: DataType)
case _ => value.toString
}
}

// TODO: Specialize
case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
extends LeafExpression with CodegenFallback {

def update(expression: Expression, input: InternalRow): Unit = {
value = expression.eval(input)
}

override def eval(input: InternalRow): Any = value
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ abstract class SubqueryExpression extends LeafExpression {
}

/**
* A subquery that will return only one row and one column.
*
* This will be converted into [[execution.ScalarSubquery]] during physical planning.
* A subquery that will return only one row and one column. This will be converted into a physical
* scalar subquery during planning.
*
* Note: `exprId` is used to have unique name in explain string output.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure.
*/
@transient
protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null)
protected[spark] final val sqlContext = SQLContext.getActive().orNull

protected def sparkContext = sqlContext.sparkContext

Expand Down Expand Up @@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

// All the subqueries and their Future of results.
@transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
/**
* List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]

/**
* Collects all the subqueries and create a Future to take the first two rows of them.
* Finds scalar subquery expressions in this plan node and starts evaluating them.
* The list of subqueries are added to [[subqueryResults]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// We only need the first row, try to take two rows so we can throw an exception if there
// are more than one rows returned.
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
queryResults += e -> futureResult
subqueryResults += e -> futureResult
}
}

/**
* Waits for all the subqueries to finish and updates the results.
* Blocks the thread until all subqueries finish evaluation and update the results.
*/
protected def waitForSubqueries(): Unit = {
// fill in the result of subqueries
queryResults.foreach {
case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// There is no rows returned, the result should be null.
e.updateResult(null)
}
subqueryResults.foreach { case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
}
queryResults.clear()
subqueryResults.clear()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ case class ScalarSubquery(
/**
* Convert the subquery from logical plan into executed plan.
*/
private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
Expand Down
61 changes: 30 additions & 31 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,65 +20,64 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("simple uncorrelated scalar subquery") {
assertResult(Array(Row(1))) {
sql("select (select 1 as b) as b").collect()
}

assertResult(Array(Row(1))) {
sql("with t2 as (select 1 as b, 2 as c) " +
"select a from (select 1 as a union all select 2 as a) t " +
"where a = (select max(b) from t2) ").collect()
}

assertResult(Array(Row(3))) {
sql("select (select (select 1) + 1) + 1").collect()
}

// more than one columns
val error = intercept[AnalysisException] {
sql("select (select 1, 2) as b").collect()
}
assert(error.message contains "Scalar subquery must return only one column, but got 2")

// more than one rows
val error2 = intercept[RuntimeException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
assert(error2.getMessage contains
"more than one row returned by a subquery used as an expression")

// string type
assertResult(Array(Row("s"))) {
sql("select (select 's' as s) as b").collect()
}
}

// zero rows
test("uncorrelated scalar subquery in CTE") {
assertResult(Array(Row(1))) {
sql("with t2 as (select 1 as b, 2 as c) " +
"select a from (select 1 as a union all select 2 as a) t " +
"where a = (select max(b) from t2) ").collect()
}
}

test("uncorrelated scalar subquery should return null if there is 0 rows") {
assertResult(Array(Row(null))) {
sql("select (select 's' as s limit 0) as b").collect()
}
}

test("uncorrelated scalar subquery on testData") {
// initialize test Data
testData
test("runtime error when the number of rows is greater than 1") {
val error2 = intercept[RuntimeException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
assert(error2.getMessage.contains(
"more than one row returned by a subquery used as an expression"))
}

test("uncorrelated scalar subquery on a DataFrame generated query") {
val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
df.registerTempTable("subqueryData")

assertResult(Array(Row(5))) {
sql("select (select key from testData where key > 3 limit 1) + 1").collect()
assertResult(Array(Row(4))) {
sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect()
}

assertResult(Array(Row(-100))) {
sql("select -(select max(key) from testData)").collect()
assertResult(Array(Row(-3))) {
sql("select -(select max(key) from subqueryData)").collect()
}

assertResult(Array(Row(null))) {
sql("select (select value from testData limit 0)").collect()
sql("select (select value from subqueryData limit 0)").collect()
}

assertResult(Array(Row("99"))) {
sql("select (select min(value) from testData" +
" where key = (select max(key) from testData) - 1)").collect()
assertResult(Array(Row("two"))) {
sql("select (select min(value) from subqueryData" +
" where key = (select max(key) from subqueryData) - 1)").collect()
}
}
}