Skip to content

Commit de512f6

Browse files
committed
use test table inline
1 parent c9bd987 commit de512f6

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ abstract class SubqueryExpression extends LeafExpression {
4545
}
4646

4747
/**
48-
* A subquery that will return only one row and one column.
49-
*
50-
* This will be converted into [[execution.ScalarSubquery]] during physical planning.
48+
* A subquery that will return only one row and one column. This will be converted into a physical
49+
* scalar subquery during planning.
5150
*
5251
* Note: `exprId` is used to have unique name in explain string output.
5352
*/

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,13 @@ package org.apache.spark.sql
2020
import org.apache.spark.sql.test.SharedSQLContext
2121

2222
class SubquerySuite extends QueryTest with SharedSQLContext {
23+
import testImplicits._
2324

2425
test("simple uncorrelated scalar subquery") {
2526
assertResult(Array(Row(1))) {
2627
sql("select (select 1 as b) as b").collect()
2728
}
2829

29-
assertResult(Array(Row(1))) {
30-
sql("with t2 as (select 1 as b, 2 as c) " +
31-
"select a from (select 1 as a union all select 2 as a) t " +
32-
"where a = (select max(b) from t2) ").collect()
33-
}
34-
3530
assertResult(Array(Row(3))) {
3631
sql("select (select (select 1) + 1) + 1").collect()
3732
}
@@ -42,17 +37,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
4237
}
4338
}
4439

45-
test("uncorrelated scalar subquery should return null if there is 0 rows") {
46-
assertResult(Array(Row(null))) {
47-
sql("select (select 's' as s limit 0) as b").collect()
40+
test("uncorrelated scalar subquery in CTE") {
41+
assertResult(Array(Row(1))) {
42+
sql("with t2 as (select 1 as b, 2 as c) " +
43+
"select a from (select 1 as a union all select 2 as a) t " +
44+
"where a = (select max(b) from t2) ").collect()
4845
}
4946
}
5047

51-
test("analysis error when the number of columns is not 1") {
52-
val error = intercept[AnalysisException] {
53-
sql("select (select 1, 2) as b").collect()
48+
test("uncorrelated scalar subquery should return null if there is 0 rows") {
49+
assertResult(Array(Row(null))) {
50+
sql("select (select 's' as s limit 0) as b").collect()
5451
}
55-
assert(error.message.contains("Scalar subquery must return only one column, but got 2"))
5652
}
5753

5854
test("runtime error when the number of rows is greater than 1") {
@@ -63,25 +59,25 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
6359
"more than one row returned by a subquery used as an expression"))
6460
}
6561

66-
test("uncorrelated scalar subquery on testData") {
67-
// initialize test Data
68-
testData
62+
test("uncorrelated scalar subquery on a DataFrame generated query") {
63+
val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
64+
df.registerTempTable("subqueryData")
6965

70-
assertResult(Array(Row(5))) {
71-
sql("select (select key from testData where key > 3 order by key limit 1) + 1").collect()
66+
assertResult(Array(Row(4))) {
67+
sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect()
7268
}
7369

74-
assertResult(Array(Row(-100))) {
75-
sql("select -(select max(key) from testData)").collect()
70+
assertResult(Array(Row(-3))) {
71+
sql("select -(select max(key) from subqueryData)").collect()
7672
}
7773

7874
assertResult(Array(Row(null))) {
79-
sql("select (select value from testData limit 0)").collect()
75+
sql("select (select value from subqueryData limit 0)").collect()
8076
}
8177

82-
assertResult(Array(Row("99"))) {
83-
sql("select (select min(value) from testData" +
84-
" where key = (select max(key) from testData) - 1)").collect()
78+
assertResult(Array(Row("two"))) {
79+
sql("select (select min(value) from subqueryData" +
80+
" where key = (select max(key) from subqueryData) - 1)").collect()
8581
}
8682
}
8783
}

0 commit comments

Comments
 (0)