
Commit af95455 (parent: f391ad2)

nsyca authored and hvanhovell committed
[SPARK-18863][SQL] Output non-aggregate expressions without GROUP BY in a subquery does not yield an error
## What changes were proposed in this pull request?

This PR reports proper error messages when a subquery expression contains an invalid plan. The problem is fixed by calling CheckAnalysis for the plan inside a subquery.

## How was this patch tested?

Existing tests and two new test cases on two forms of subquery, namely, scalar subquery and in/exists subquery.

```sql
-- TC 01.01
-- The column t2b in the SELECT of the subquery is invalid
-- because it is neither an aggregate function nor a GROUP BY column.
select t1a, t2b
from   t1, t2
where  t1b = t2c
and    t2b = (select max(avg)
              from   (select t2b, avg(t2b) avg
                      from   t2
                      where  t2a = t1.t1b
                     )
              )
;

-- TC 01.02
-- Invalid due to the column t2b not part of the output from table t2.
select *
from   t1
where  t1a in (select min(t2a)
               from   t2
               group by t2c
               having t2c in (select max(t3c)
                              from   t3
                              group by t3b
                              having t3b > t2b ))
;
```

Author: Nattavut Sutyanyong <nsy.can@gmail.com>

Closes #16572 from nsyca/18863.

(cherry picked from commit f1ddca5)
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
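For context, TC 01.01 above can be reproduced interactively. The following is a minimal sketch, not part of the commit: it assumes a spark-shell session (so `spark` is an already-constructed SparkSession) and recreates the one-row views from the new test file. The expected failure is the grouping error recorded in the generated output file further down.

```scala
// Minimal reproduction sketch for TC 01.01 (assumes a spark-shell session, where `spark` is in scope).
import org.apache.spark.sql.AnalysisException

spark.sql("create temporary view t1 as select * from values (1, 2, 3) as t1(t1a, t1b, t1c)")
spark.sql("create temporary view t2 as select * from values (1, 0, 1) as t2(t2a, t2b, t2c)")

// t2b is selected in the innermost query without being aggregated or grouped.
try {
  spark.sql(
    """select t1a, t2b
      |from   t1, t2
      |where  t1b = t2c
      |and    t2b = (select max(avg)
      |              from   (select t2b, avg(t2b) avg
      |                      from   t2
      |                      where  t2a = t1.t1b))""".stripMargin).collect()
} catch {
  // Expected with this patch: "expression 't2.`t2b`' is neither present in the group by,
  // nor is it an aggregate function. ..."
  case e: AnalysisException => println(e.getMessage)
}
```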

4 files changed: +168, -51 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 56 additions & 50 deletions
```diff
@@ -117,66 +117,72 @@ trait CheckAnalysis extends PredicateHelper {
                 failAnalysis(s"Window specification $s is not valid because $m")
               case None => w
             }
-          case s @ ScalarSubquery(query, conditions, _)
+
+          case s @ ScalarSubquery(query, conditions, _) =>
             // If no correlation, the output must be exactly one column
-            if (conditions.isEmpty && query.output.size != 1) =>
+            if (conditions.isEmpty && query.output.size != 1) {
               failAnalysis(
                 s"Scalar subquery must return only one column, but got ${query.output.size}")
+            }
+            else if (conditions.nonEmpty) {
+              // Collect the columns from the subquery for further checking.
+              var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
+
+              def checkAggregate(agg: Aggregate): Unit = {
+                // Make sure correlated scalar subqueries contain one row for every outer row by
+                // enforcing that they are aggregates containing exactly one aggregate expression.
+                // The analyzer has already checked that subquery contained only one output column,
+                // and added all the grouping expressions to the aggregate.
+                val aggregates = agg.expressions.flatMap(_.collect {
+                  case a: AggregateExpression => a
+                })
+                if (aggregates.isEmpty) {
+                  failAnalysis("The output of a correlated scalar subquery must be aggregated")
+                }
 
-          case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
-
-            // Collect the columns from the subquery for further checking.
-            var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
-
-            def checkAggregate(agg: Aggregate): Unit = {
-              // Make sure correlated scalar subqueries contain one row for every outer row by
-              // enforcing that they are aggregates which contain exactly one aggregate expressions.
-              // The analyzer has already checked that subquery contained only one output column,
-              // and added all the grouping expressions to the aggregate.
-              val aggregates = agg.expressions.flatMap(_.collect {
-                case a: AggregateExpression => a
-              })
-              if (aggregates.isEmpty) {
-                failAnalysis("The output of a correlated scalar subquery must be aggregated")
+                // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
+                // are not part of the correlated columns.
+                val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
+                val correlatedCols = AttributeSet(subqueryColumns)
+                val invalidCols = groupByCols -- correlatedCols
+                // GROUP BY columns must be a subset of columns in the predicates
+                if (invalidCols.nonEmpty) {
+                  failAnalysis(
+                    "A GROUP BY clause in a scalar correlated subquery " +
+                      "cannot contain non-correlated columns: " +
+                      invalidCols.mkString(","))
+                }
               }
 
-              // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
-              // are not part of the correlated columns.
-              val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
-              val correlatedCols = AttributeSet(subqueryColumns)
-              val invalidCols = groupByCols -- correlatedCols
-              // GROUP BY columns must be a subset of columns in the predicates
-              if (invalidCols.nonEmpty) {
-                failAnalysis(
-                  "A GROUP BY clause in a scalar correlated subquery " +
-                    "cannot contain non-correlated columns: " +
-                    invalidCols.mkString(","))
-              }
-            }
+              // Skip subquery aliases added by the Analyzer and the SQLBuilder.
+              // For projects, do the necessary mapping and skip to its child.
+              def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
+                case s: SubqueryAlias => cleanQuery(s.child)
+                case p: Project =>
+                  // SPARK-18814: Map any aliases to their AttributeReference children
+                  // for the checking in the Aggregate operators below this Project.
+                  subqueryColumns = subqueryColumns.map {
+                    xs => p.projectList.collectFirst {
+                      case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
+                        child
+                    }.getOrElse(xs)
+                  }
 
-            // Skip subquery aliases added by the Analyzer and the SQLBuilder.
-            // For projects, do the necessary mapping and skip to its child.
-            def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
-              case s: SubqueryAlias => cleanQuery(s.child)
-              case p: Project =>
-                // SPARK-18814: Map any aliases to their AttributeReference children
-                // for the checking in the Aggregate operators below this Project.
-                subqueryColumns = subqueryColumns.map {
-                  xs => p.projectList.collectFirst {
-                    case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
-                      child
-                  }.getOrElse(xs)
-                }
+                  cleanQuery(p.child)
+                case child => child
+              }
 
-                cleanQuery(p.child)
-              case child => child
+              cleanQuery(query) match {
+                case a: Aggregate => checkAggregate(a)
+                case Filter(_, a: Aggregate) => checkAggregate(a)
+                case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
+              }
             }
+            checkAnalysis(query)
+            s
 
-            cleanQuery(query) match {
-              case a: Aggregate => checkAggregate(a)
-              case Filter(_, a: Aggregate) => checkAggregate(a)
-              case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail")
-            }
+          case s: SubqueryExpression =>
+            checkAnalysis(s.plan)
             s
         }
 
```
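The core of the SPARK-18504/SPARK-18814 guard in the hunk above is a set difference: every GROUP BY column of a correlated scalar subquery must also appear in a correlating predicate. The toy sketch below models just that rule with plain Scala sets standing in for Catalyst AttributeSets; it is an illustration that can be pasted into a Scala REPL, not the analyzer code itself.

```scala
// Toy model of the rule: GROUP BY columns of a correlated scalar subquery
// must be a subset of the correlated columns. Strings stand in for Catalyst attributes.
def checkGroupByIsCorrelated(groupByCols: Set[String], correlatedCols: Set[String]): Unit = {
  val invalidCols = groupByCols -- correlatedCols
  if (invalidCols.nonEmpty) {
    sys.error(
      "A GROUP BY clause in a scalar correlated subquery " +
        "cannot contain non-correlated columns: " + invalidCols.mkString(","))
  }
}

checkGroupByIsCorrelated(Set("t2a"), Set("t2a"))    // passes: grouped only on a correlated column
// checkGroupByIsCorrelated(Set("t2b"), Set("t2a")) // would fail: t2b never appears in a correlating predicate
```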

Lines changed: 42 additions & 0 deletions (new SQL test input file)
```diff
@@ -0,0 +1,42 @@
+-- The test file contains negative test cases
+-- of invalid queries where error messages are expected.
+
+create temporary view t1 as select * from values
+  (1, 2, 3)
+  as t1(t1a, t1b, t1c);
+
+create temporary view t2 as select * from values
+  (1, 0, 1)
+  as t2(t2a, t2b, t2c);
+
+create temporary view t3 as select * from values
+  (3, 1, 2)
+  as t3(t3a, t3b, t3c);
+
+-- TC 01.01
+-- The column t2b in the SELECT of the subquery is invalid
+-- because it is neither an aggregate function nor a GROUP BY column.
+select t1a, t2b
+from   t1, t2
+where  t1b = t2c
+and    t2b = (select max(avg)
+              from   (select t2b, avg(t2b) avg
+                      from   t2
+                      where  t2a = t1.t1b
+                     )
+              )
+;
+
+-- TC 01.02
+-- Invalid due to the column t2b not part of the output from table t2.
+select *
+from   t1
+where  t1a in (select min(t2a)
+               from   t2
+               group by t2c
+               having t2c in (select max(t3c)
+                              from   t3
+                              group by t3b
+                              having t3b > t2b ))
+;
+
```
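The commit asserts these failures through SQLQueryTestSuite golden files rather than explicit assertions. Purely as an illustration of the same expectation, TC 01.02 could also be checked with a ScalaTest-style intercept. The sketch below is hypothetical and not part of the commit: it assumes it runs inside a suite mixing in SharedSQLContext (so `spark` and `intercept` are in scope) and that the t1/t2/t3 views above have already been created.

```scala
// Hypothetical assertion for TC 01.02; the commit itself relies on golden files instead.
import org.apache.spark.sql.AnalysisException

val e = intercept[AnalysisException] {
  spark.sql(
    """select *
      |from   t1
      |where  t1a in (select min(t2a)
      |               from   t2
      |               group by t2c
      |               having t2c in (select max(t3c)
      |                              from   t3
      |                              group by t3b
      |                              having t3b > t2b))""".stripMargin)
}
// The raw message carries numeric expression IDs; the golden file shows them masked as "#x".
assert(e.getMessage.contains("resolved attribute(s)"))
```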
Lines changed: 66 additions & 0 deletions (new expected-output file generated by SQLQueryTestSuite)
```diff
@@ -0,0 +1,66 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 5
+
+
+-- !query 0
+create temporary view t1 as select * from values
+  (1, 2, 3)
+  as t1(t1a, t1b, t1c)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+create temporary view t2 as select * from values
+  (1, 0, 1)
+  as t2(t2a, t2b, t2c)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+create temporary view t3 as select * from values
+  (3, 1, 2)
+  as t3(t3a, t3b, t3c)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+select t1a, t2b
+from   t1, t2
+where  t1b = t2c
+and    t2b = (select max(avg)
+              from   (select t2b, avg(t2b) avg
+                      from   t2
+                      where  t2a = t1.t1b
+                     )
+              )
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
+
+
+-- !query 4
+select *
+from   t1
+where  t1a in (select min(t2a)
+               from   t2
+               group by t2c
+               having t2c in (select max(t3c)
+                              from   t3
+                              group by t3b
+                              having t3b > t2b ))
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)];
```

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -228,7 +228,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
     } catch {
       case a: AnalysisException if a.plan.nonEmpty =>
         // Do not output the logical plan tree which contains expression IDs.
-        (StructType(Seq.empty), Seq(a.getClass.getName, a.getSimpleMessage))
+        // Also implement a crude way of masking expression IDs in the error message
+        // with a generic pattern "###".
+        (StructType(Seq.empty),
+          Seq(a.getClass.getName, a.getSimpleMessage.replaceAll("#\\d+", "#x")))
       case NonFatal(e) =>
         // If there is an exception, put the exception class followed by the message.
         (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
```
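The replaceAll call above is an ordinary regex substitution over the error message, so the masking can be sanity-checked in isolation. A small standalone sketch follows; the numeric expression IDs in the sample string are invented for illustration.

```scala
// Standalone check of the expression-ID masking pattern used in SQLQueryTestSuite.
// The sample IDs (#267 etc.) are made up; real ones are assigned by the analyzer.
val raw = "resolved attribute(s) t2b#267 missing from min(t2a)#270,t2c#265"
val masked = raw.replaceAll("#\\d+", "#x")
assert(masked == "resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x")
```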
