
Commit a6114d8

maropu authored and cloud-fan committed
[SPARK-32638][SQL] Corrects references when adding aliases in WidenSetOperationTypes
### What changes were proposed in this pull request?

This PR intends to fix a bug where references can go missing when adding aliases to widen data types in `WidenSetOperationTypes`. For example:

```
CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
SELECT t.v FROM (
  SELECT v FROM t3
  UNION ALL
  SELECT v + v AS v FROM t3
) t;

org.apache.spark.sql.AnalysisException: Resolved attribute(s) v#1 missing from v#3 in operator !Project [v#1]. Attribute(s) with the same name appear in the operation: v. Please check if the right attribute(s) are used.;;
!Project [v#1]   <------ the reference went missing
+- SubqueryAlias t
   +- Union
      :- Project [cast(v#1 as decimal(11,0)) AS v#3]
      :  +- Project [v#1]
      :     +- SubqueryAlias t3
      :        +- SubqueryAlias tbl
      :           +- LocalRelation [v#1]
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            +- SubqueryAlias t3
               +- SubqueryAlias tbl
                  +- LocalRelation [v#1]
```

In this case, `WidenSetOperationTypes` added the alias `cast(v#1 as decimal(11,0)) AS v#3`, and then the reference in the top `Project` went missing. This PR corrects the reference (the `exprId` and the widened `dataType`) after adding aliases in the rule.

### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests.

Closes #29485 from maropu/SPARK-32638.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent ffd5227 commit a6114d8
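
To make the failure mode concrete, here is a minimal, self-contained sketch (plain Scala, not Spark code; `Attr` and the integer IDs are hypothetical stand-ins for Catalyst's `AttributeReference` and `ExprId`). A widening alias that mints a fresh expression ID orphans the parent's reference unless an old-to-new attribute mapping is propagated upward:

```scala
// Toy model of the bug and the fix. Before the patch, the widening Project
// changed the child's output id (v#1 -> v#3) while the parent kept
// referencing v#1, which no longer existed below it.
case class Attr(name: String, exprId: Int)

object MissingRefSketch {
  def main(args: Array[String]): Unit = {
    val v1 = Attr("v", 1)              // child output before widening
    val parentRefs = Seq(v1)           // the parent Project references v#1
    val widened = Attr("v", 3)         // widening alias output: v#3
    val childOutput = Seq(widened)
    // Without a repair step, the parent's reference is now unresolvable.
    assert(!childOutput.map(_.exprId).contains(parentRefs.head.exprId))
    // The fix propagates an (old -> new) mapping upward and rewrites parents.
    val attrMapping = Map(v1.exprId -> widened.exprId)
    val repaired = parentRefs.map(a => a.copy(exprId = attrMapping.getOrElse(a.exprId, a.exprId)))
    assert(childOutput.map(_.exprId).contains(repaired.head.exprId))
  }
}
```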

File tree: 9 files changed, +378 -134 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 122 additions & 103 deletions
```diff
@@ -123,6 +123,127 @@ object AnalysisContext {
   }
 }
 
+object Analyzer {
+
+  /**
+   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
+   * This method also updates all the related references in the `plan` accordingly.
+   *
+   * @param plan to rewrite
+   * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
+   * @return a rewritten plan and updated references related to a root node of
+   *         the given `plan` for rewriting it.
+   */
+  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
+    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
+    if (plan.resolved) {
+      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+      val newChildren = plan.children.map { child =>
+        // If not, we'd rewrite child plan recursively until we find the
+        // conflict node or reach the leaf node.
+        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
+        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+          // `attrMapping` is not only used to replace the attributes of the current `plan`,
+          // but also to be propagated to the parent plans of the current `plan`. Therefore,
+          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
+          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
+          // used by those parent plans).
+          (plan.outputSet ++ plan.references).contains(oldAttr)
+        }
+        newChild
+      }
+
+      val newPlan = if (rewritePlanMap.contains(plan)) {
+        rewritePlanMap(plan).withNewChildren(newChildren)
+      } else {
+        plan.withNewChildren(newChildren)
+      }
+
+      assert(!attrMapping.groupBy(_._1.exprId)
+        .exists(_._2.map(_._2.exprId).distinct.length > 1),
+        "Found duplicate rewrite attributes")
+
+      val attributeRewrites = AttributeMap(attrMapping)
+      // Using attrMapping from the children plans to rewrite their parent node.
+      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+      val p = newPlan.transformExpressions {
+        case a: Attribute =>
+          updateAttr(a, attributeRewrites)
+        case s: SubqueryExpression =>
+          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
+      }
+      attrMapping ++= plan.output.zip(p.output)
+        .filter { case (a1, a2) => a1.exprId != a2.exprId }
+      p -> attrMapping
+    } else {
+      // Just passes through unresolved nodes
+      plan.mapChildren {
+        rewritePlan(_, rewritePlanMap)._1
+      } -> Nil
+    }
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   *
+   * For example (SQL):
+   * {{{
+   *   SELECT * FROM t1
+   *   INTERSECT
+   *   SELECT * FROM t1
+   *   WHERE EXISTS (SELECT 1
+   *                 FROM t2
+   *                 WHERE t1.c1 = t2.c1)
+   * }}}
+   * Plan before resolveReference rule.
+   *    'Intersect
+   *    :- Project [c1#245, c2#246]
+   *    :  +- SubqueryAlias t1
+   *    :     +- Relation[c1#245,c2#246] parquet
+   *    +- 'Project [*]
+   *       +- Filter exists#257 [c1#245]
+   *          :  +- Project [1 AS 1#258]
+   *          :     +- Filter (outer(c1#245) = c1#251)
+   *          :        +- SubqueryAlias t2
+   *          :           +- Relation[c1#251,c2#252] parquet
+   *          +- SubqueryAlias t1
+   *             +- Relation[c1#245,c2#246] parquet
+   * Plan after the resolveReference rule.
+   *    Intersect
+   *    :- Project [c1#245, c2#246]
+   *    :  +- SubqueryAlias t1
+   *    :     +- Relation[c1#245,c2#246] parquet
+   *    +- Project [c1#259, c2#260]
+   *       +- Filter exists#257 [c1#259]
+   *          :  +- Project [1 AS 1#258]
+   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
+   *          :        +- SubqueryAlias t2
+   *          :           +- Relation[c1#251,c2#252] parquet
+   *          +- SubqueryAlias t1
+   *             +- Relation[c1#259,c2#260] parquet  => Outer plan's attributes are rewritten.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: LogicalPlan,
+      attrMap: AttributeMap[Attribute]): LogicalPlan = {
+    AnalysisHelper.allowInvokingTransformsInAnalyzer {
+      plan transformDown { case currentFragment =>
+        currentFragment transformExpressions {
+          case OuterReference(a: Attribute) =>
+            OuterReference(updateAttr(a, attrMap))
+          case s: SubqueryExpression =>
+            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
+        }
+      }
+    }
+  }
+}
+
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
  * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -1255,109 +1376,7 @@
     if (conflictPlans.isEmpty) {
       right
     } else {
-      rewritePlan(right, conflictPlans.toMap)._1
-    }
-  }
-
-  private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan])
-    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-    if (conflictPlanMap.contains(plan)) {
-      // If the plan is the one that conflict the with left one, we'd
-      // just replace it with the new plan and collect the rewrite
-      // attributes for the parent node.
-      val newRelation = conflictPlanMap(plan)
-      newRelation -> plan.output.zip(newRelation.output)
-    } else {
-      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-      val newPlan = plan.mapChildren { child =>
-        // If not, we'd rewrite child plan recursively until we find the
-        // conflict node or reach the leaf node.
-        val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap)
-        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-          // `attrMapping` is not only used to replace the attributes of the current `plan`,
-          // but also to be propagated to the parent plans of the current `plan`. Therefore,
-          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-          // used by those parent plans).
-          (plan.outputSet ++ plan.references).contains(oldAttr)
-        }
-        newChild
-      }
-
-      if (attrMapping.isEmpty) {
-        newPlan -> attrMapping.toSeq
-      } else {
-        assert(!attrMapping.groupBy(_._1.exprId)
-          .exists(_._2.map(_._2.exprId).distinct.length > 1),
-          "Found duplicate rewrite attributes")
-        val attributeRewrites = AttributeMap(attrMapping.toSeq)
-        // Using attrMapping from the children plans to rewrite their parent node.
-        // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-        newPlan.transformExpressions {
-          case a: Attribute =>
-            dedupAttr(a, attributeRewrites)
-          case s: SubqueryExpression =>
-            s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
-        } -> attrMapping.toSeq
-      }
-    }
-  }
-
-  private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-    val exprId = attrMap.getOrElse(attr, attr).exprId
-    attr.withExprId(exprId)
-  }
-
-  /**
-   * The outer plan may have been de-duplicated and the function below updates the
-   * outer references to refer to the de-duplicated attributes.
-   *
-   * For example (SQL):
-   * {{{
-   *   SELECT * FROM t1
-   *   INTERSECT
-   *   SELECT * FROM t1
-   *   WHERE EXISTS (SELECT 1
-   *                 FROM t2
-   *                 WHERE t1.c1 = t2.c1)
-   * }}}
-   * Plan before resolveReference rule.
-   *    'Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- 'Project [*]
-   *       +- Filter exists#257 [c1#245]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#245) = c1#251)
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#245,c2#246] parquet
-   * Plan after the resolveReference rule.
-   *    Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- Project [c1#259, c2#260]
-   *       +- Filter exists#257 [c1#259]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#259,c2#260] parquet  => Outer plan's attributes are de-duplicated.
-   */
-  private def dedupOuterReferencesInSubquery(
-      plan: LogicalPlan,
-      attrMap: AttributeMap[Attribute]): LogicalPlan = {
-    plan transformDown { case currentFragment =>
-      currentFragment transformExpressions {
-        case OuterReference(a: Attribute) =>
-          OuterReference(dedupAttr(a, attrMap))
-        case s: SubqueryExpression =>
-          s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
-      }
+      Analyzer.rewritePlan(right, conflictPlans.toMap)._1
     }
   }
 
```
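
For orientation, here is an illustrative use of the relocated helper, a sketch that assumes the spark-catalyst test classpath at this commit (the relation and plan names are made up for the example). It swaps a mapped subplan in and remaps the now-stale reference in the parent plan it returns through:

```scala
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.IntegerType

object RewritePlanSketch extends App {
  val a = AttributeReference("a", IntegerType)()
  val rel = LocalRelation(a)
  val oldChild = rel.select(a)
  // Replacement child: same data, but its output attribute carries a fresh
  // exprId via an Alias, like the widening Project the coercion rule inserts.
  val newChild = rel.select(Alias(a, "a")())
  val plan = oldChild.select(oldChild.output.head)

  // Swapping oldChild for newChild naively would orphan the top Project's
  // reference; rewritePlan remaps it to the new exprId while rebuilding.
  val (rewritten, attrMapping) = Analyzer.rewritePlan(plan, Map(oldChild -> newChild))
  assert(rewritten.missingInput.isEmpty)
  assert(attrMapping.nonEmpty)
}
```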

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 53 additions & 26 deletions
```diff
@@ -326,29 +326,53 @@ object TypeCoercion {
    *
    * This rule is only applied to Union/Except/Intersect
    */
-  object WidenSetOperationTypes extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
-      case s @ Except(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Except(newChildren.head, newChildren.last, isAll)
-
-      case s @ Intersect(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Intersect(newChildren.head, newChildren.last, isAll)
-
-      case s: Union if s.childrenResolved && !s.byName &&
+  object WidenSetOperationTypes extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
+      val newPlan = plan resolveOperatorsUp {
+        case s @ Except(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            Except(newChildren.head._1, newChildren.last._1, isAll)
+          } else {
+            s
+          }
+
+        case s @ Intersect(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            Intersect(newChildren.head._1, newChildren.last._1, isAll)
+          } else {
+            s
+          }
+
+        case s: Union if s.childrenResolved && !s.byName &&
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
-        s.copy(children = newChildren)
+          val newChildren = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.nonEmpty) {
+            rewritePlanMap ++= newChildren
+            s.copy(children = newChildren.map(_._1))
+          } else {
+            s
+          }
+      }
+
+      if (rewritePlanMap.nonEmpty) {
+        assert(!plan.fastEquals(newPlan))
+        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
+      } else {
+        plan
+      }
     }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
+      : Seq[(LogicalPlan, LogicalPlan)] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -360,8 +384,7 @@
         // Add an extra Project if the targetTypes are different from the original types.
         children.map(widenTypes(_, targetTypes))
       } else {
-        // Unable to find a target type to widen, then just return the original set.
-        children
+        Nil
       }
     }
 
@@ -385,12 +408,16 @@
     }
 
     /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
+      : (LogicalPlan, LogicalPlan) = {
       val casted = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
-        case (e, _) => e
-      }
-      Project(casted, plan)
+        case (e, dt) if e.dataType != dt =>
+          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
+          alias -> alias.newInstance()
+        case (e, _) =>
+          e -> e
+      }.unzip
+      Project(casted._1, plan) -> Project(casted._2, plan)
     }
   }
 
```
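
The subtle piece above is the pair returned by `widenTypes`: the first Project reuses the child attribute's `exprId`, so it can stand in for the original child as a key in the rewrite map, while the second carries a fresh `exprId` that parent references get remapped to. A small sketch of just that pairing, under the same classpath assumption (names are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.DecimalType

object WidenTypesPairSketch extends App {
  val v = AttributeReference("v", DecimalType(10, 0))()
  val rel = LocalRelation(v)
  // One alias per widened column, built once with the child's exprId...
  val alias = Alias(Cast(v, DecimalType(11, 0)), v.name)(exprId = v.exprId)
  // ...then split into an (old, new) Project pair for the rewrite map.
  val oldProj = Project(Seq(alias), rel)                // output keeps v's exprId
  val newProj = Project(Seq(alias.newInstance()), rel)  // output gets a fresh exprId
  assert(oldProj.output.head.exprId == v.exprId)
  assert(newProj.output.head.exprId != v.exprId)
}
```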

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 15 additions & 2 deletions
```diff
@@ -21,13 +21,12 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
 
 class TypeCoercionSuite extends AnalysisTest {
   import TypeCoercionSuite._
@@ -1417,6 +1416,20 @@
     }
   }
 
+  test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
+    val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
+    val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
+    val p1 = t1.select(t1.output.head)
+    val p2 = t2.select(t2.output.head)
+    val union = p1.union(p2)
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.missingInput.isEmpty)
+    val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
+    assert(wp2.isInstanceOf[Aggregate])
+    assert(wp2.missingInput.isEmpty)
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.
```
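
As an end-to-end check, the repro from the description can be run through a local session; this is a sketch assuming a Spark build that contains this patch (before it, analysis failed with `Resolved attribute(s) v#1 missing ...`):

```scala
import org.apache.spark.sql.SparkSession

object Spark32638Repro extends App {
  val spark = SparkSession.builder().master("local[1]").appName("SPARK-32638").getOrCreate()
  spark.sql("CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)")
  // With the fix, the query analyzes and returns the two expected rows.
  spark.sql(
    """SELECT t.v FROM (
      |  SELECT v FROM t3
      |  UNION ALL
      |  SELECT v + v AS v FROM t3
      |) t""".stripMargin).show()
  spark.stop()
}
```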
