
Commit 3f20f14

cloud-fan and maropu committed
[SPARK-32638][SQL][3.0] Corrects references when adding aliases in WidenSetOperationTypes
### What changes were proposed in this pull request?

This PR intends to fix a bug where references can go missing when adding aliases to widen data types in `WidenSetOperationTypes`. For example:

```
CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
SELECT t.v FROM (
  SELECT v FROM t3
  UNION ALL
  SELECT v + v AS v FROM t3
) t;

org.apache.spark.sql.AnalysisException: Resolved attribute(s) v#1 missing from v#3 in operator !Project [v#1]. Attribute(s) with the same name appear in the operation: v. Please check if the right attribute(s) are used.;;
!Project [v#1]   <------ the reference went missing
+- SubqueryAlias t
   +- Union
      :- Project [cast(v#1 as decimal(11,0)) AS v#3]
      :  +- Project [v#1]
      :     +- SubqueryAlias t3
      :        +- SubqueryAlias tbl
      :           +- LocalRelation [v#1]
      +- Project [v#2]
         +- Project [CheckOverflow((promote_precision(cast(v#1 as decimal(11,0))) + promote_precision(cast(v#1 as decimal(11,0)))), DecimalType(11,0), true) AS v#2]
            +- SubqueryAlias t3
               +- SubqueryAlias tbl
                  +- LocalRelation [v#1]
```

In this case, `WidenSetOperationTypes` added the alias `cast(v#1 as decimal(11,0)) AS v#3`, and the reference in the top `Project` then went missing. This PR corrects the references (`exprId` and widened `dataType`) after adding aliases in the rule.

This backport for 3.0 comes from #29485 and #29643.

### Why are the changes needed?

Bug fixes.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests.

Closes #29680 from maropu/SPARK-32638-BRANCH3.0.

Lead-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
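For illustration, a minimal repro sketch of the user-visible behavior (hedged: assumes a `spark` session, e.g. in `spark-shell`; the printed table shape is illustrative):

```scala
// Hedged repro sketch: without this patch the analyzer fails with
// "Resolved attribute(s) v#... missing"; with it, the query resolves
// because the widened alias's new exprId is propagated to the parent.
spark.sql("CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)")
spark.sql(
  """SELECT t.v FROM (
    |  SELECT v FROM t3
    |  UNION ALL
    |  SELECT v + v AS v FROM t3
    |) t""".stripMargin).show()
// Expected (v widened to decimal(11,0)):
// +---+
// |  v|
// +---+
// |  1|
// |  2|
// +---+
```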
1 parent 8c0b9cb commit 3f20f14

File tree

11 files changed (+348, -129 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 7 additions & 102 deletions
```diff
@@ -1239,108 +1239,13 @@ class Analyzer(
       if (conflictPlans.isEmpty) {
         right
       } else {
-        rewritePlan(right, conflictPlans.toMap)._1
-      }
-    }
-
-    private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan])
-      : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-      if (conflictPlanMap.contains(plan)) {
-        // If the plan is the one that conflict the with left one, we'd
-        // just replace it with the new plan and collect the rewrite
-        // attributes for the parent node.
-        val newRelation = conflictPlanMap(plan)
-        newRelation -> plan.output.zip(newRelation.output)
-      } else {
-        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-        val newPlan = plan.mapChildren { child =>
-          // If not, we'd rewrite child plan recursively until we find the
-          // conflict node or reach the leaf node.
-          val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap)
-          attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-            // `attrMapping` is not only used to replace the attributes of the current `plan`,
-            // but also to be propagated to the parent plans of the current `plan`. Therefore,
-            // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-            // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-            // used by those parent plans).
-            (plan.outputSet ++ plan.references).contains(oldAttr)
-          }
-          newChild
-        }
-
-        if (attrMapping.isEmpty) {
-          newPlan -> attrMapping
-        } else {
-          assert(!attrMapping.groupBy(_._1.exprId)
-            .exists(_._2.map(_._2.exprId).distinct.length > 1),
-            "Found duplicate rewrite attributes")
-          val attributeRewrites = AttributeMap(attrMapping)
-          // Using attrMapping from the children plans to rewrite their parent node.
-          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-          newPlan.transformExpressions {
-            case a: Attribute =>
-              dedupAttr(a, attributeRewrites)
-            case s: SubqueryExpression =>
-              s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
-          } -> attrMapping
-        }
-      }
-    }
-
-    private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-      val exprId = attrMap.getOrElse(attr, attr).exprId
-      attr.withExprId(exprId)
-    }
-
-    /**
-     * The outer plan may have been de-duplicated and the function below updates the
-     * outer references to refer to the de-duplicated attributes.
-     *
-     * For example (SQL):
-     * {{{
-     *   SELECT * FROM t1
-     *   INTERSECT
-     *   SELECT * FROM t1
-     *   WHERE EXISTS (SELECT 1
-     *                 FROM t2
-     *                 WHERE t1.c1 = t2.c1)
-     * }}}
-     * Plan before resolveReference rule.
-     *    'Intersect
-     *    :- Project [c1#245, c2#246]
-     *    :  +- SubqueryAlias t1
-     *    :     +- Relation[c1#245,c2#246] parquet
-     *    +- 'Project [*]
-     *       +- Filter exists#257 [c1#245]
-     *          :  +- Project [1 AS 1#258]
-     *          :     +- Filter (outer(c1#245) = c1#251)
-     *          :        +- SubqueryAlias t2
-     *          :           +- Relation[c1#251,c2#252] parquet
-     *          +- SubqueryAlias t1
-     *             +- Relation[c1#245,c2#246] parquet
-     * Plan after the resolveReference rule.
-     *    Intersect
-     *    :- Project [c1#245, c2#246]
-     *    :  +- SubqueryAlias t1
-     *    :     +- Relation[c1#245,c2#246] parquet
-     *    +- Project [c1#259, c2#260]
-     *       +- Filter exists#257 [c1#259]
-     *          :  +- Project [1 AS 1#258]
-     *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-     *          :        +- SubqueryAlias t2
-     *          :           +- Relation[c1#251,c2#252] parquet
-     *          +- SubqueryAlias t1
-     *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
-     */
-    private def dedupOuterReferencesInSubquery(
-        plan: LogicalPlan,
-        attrMap: AttributeMap[Attribute]): LogicalPlan = {
-      plan transformDown { case currentFragment =>
-        currentFragment transformExpressions {
-          case OuterReference(a: Attribute) =>
-            OuterReference(dedupAttr(a, attrMap))
-          case s: SubqueryExpression =>
-            s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
+        val planMapping = conflictPlans.toMap
+        right.transformUpWithNewOutput {
+          case oldPlan =>
+            val newPlanOpt = planMapping.get(oldPlan)
+            newPlanOpt.map { newPlan =>
+              newPlan -> oldPlan.output.zip(newPlan.output)
+            }.getOrElse(oldPlan -> Nil)
         }
       }
     }
```
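For context, the rewritten branch above lives in the analyzer's dedup path for conflicting attributes (typically self-joins); a hedged sketch of the scenario it serves:

```scala
// Hedged illustration: both sides of a self-join initially carry the same
// expr IDs, so the analyzer re-instantiates the conflicting right-side
// plans and (now via transformUpWithNewOutput) rewrites parent references.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("dedup-sketch").getOrCreate()
val df = spark.range(3)
df.join(df, Seq("id")).show()  // resolves only because the right side was deduplicated
```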

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 38 additions & 21 deletions
```diff
@@ -326,25 +326,42 @@ object TypeCoercion {
    *
    * This rule is only applied to Union/Except/Intersect
    */
-  object WidenSetOperationTypes extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
-      case s @ Except(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Except(newChildren.head, newChildren.last, isAll)
-
-      case s @ Intersect(left, right, isAll) if s.childrenResolved &&
-          left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-        assert(newChildren.length == 2)
-        Intersect(newChildren.head, newChildren.last, isAll)
-
-      case s: Union if s.childrenResolved &&
+  object WidenSetOperationTypes extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      plan resolveOperatorsUpWithNewOutput {
+        case s @ Except(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
+          } else {
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Except(newChildren.head, newChildren.last, isAll) -> attrMapping
+          }
+
+        case s @ Intersect(left, right, isAll) if s.childrenResolved &&
+            left.output.length == right.output.length && !s.resolved =>
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
+          } else {
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
+          }
+
+        case s: Union if s.childrenResolved &&
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
-        s.makeCopy(Array(newChildren))
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.isEmpty) {
+            s -> Nil
+          } else {
+            val attrMapping = s.children.head.output.zip(newChildren.head.output)
+            s.copy(children = newChildren) -> attrMapping
+          }
+      }
     }

     /** Build new children with the widest types for each attribute among all the children */
@@ -360,8 +377,7 @@ object TypeCoercion {
       // Add an extra Project if the targetTypes are different from the original types.
       children.map(widenTypes(_, targetTypes))
     } else {
-      // Unable to find a target type to widen, then just return the original set.
-      children
+      Nil
     }
   }

@@ -387,7 +403,8 @@ object TypeCoercion {
     /** Given a plan, add an extra project on top to widen some columns' data types. */
     private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
       val casted = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+        case (e, dt) if e.dataType != dt =>
+          Alias(Cast(e, dt, Some(SQLConf.get.sessionLocalTimeZone)), e.name)()
         case (e, _) => e
       }
       Project(casted, plan)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 85 additions & 0 deletions
```diff
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.plans

+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
@@ -168,6 +170,89 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     }.toSeq
   }

+  /**
+   * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node
+   * with a new one that has different output expr IDs, by updating the attribute references in
+   * the parent nodes accordingly.
+   *
+   * @param rule the function to transform plan nodes, and return new nodes with attributes mapping
+   *             from old attributes to new attributes. The attribute mapping will be used to
+   *             rewrite attribute references in the parent nodes.
+   * @param skipCond a boolean condition to indicate if we can skip transforming a plan node to save
+   *                 time.
+   */
+  def transformUpWithNewOutput(
+      rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])],
+      skipCond: PlanType => Boolean = _ => false): PlanType = {
+    def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
+      if (skipCond(plan)) {
+        plan -> Nil
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        var newPlan = plan.mapChildren { child =>
+          val (newChild, childAttrMapping) = rewrite(child)
+          attrMapping ++= childAttrMapping
+          newChild
+        }
+
+        val attrMappingForCurrentPlan = attrMapping.filter {
+          // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+          // current `plan`, so the `oldAttr` must be part of `plan.references`.
+          case (oldAttr, _) => plan.references.contains(oldAttr)
+        }
+
+        val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
+        }
+        newPlan = planAfterRule
+
+        if (attrMappingForCurrentPlan.nonEmpty) {
+          assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+
+          val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+          // Using attrMapping from the children plans to rewrite their parent node.
+          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+          newPlan = newPlan.transformExpressions {
+            case a: AttributeReference =>
+              updateAttr(a, attributeRewrites)
+            case pe: PlanExpression[PlanType] =>
+              pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
+          }
+        }
+
+        attrMapping ++= newAttrMapping.filter {
+          case (a1, a2) => a1.exprId != a2.exprId
+        }
+        newPlan -> attrMapping
+      }
+    }
+    rewrite(this)._1
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: PlanType,
+      attrMap: AttributeMap[Attribute]): PlanType = {
+    plan.transformDown { case currentFragment =>
+      currentFragment.transformExpressions {
+        case OuterReference(a: AttributeReference) =>
+          OuterReference(updateAttr(a, attrMap))
+        case pe: PlanExpression[PlanType] =>
+          pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
+      }
+    }
+  }
+
   lazy val schema: StructType = StructType.fromAttributes(output)

   /** Returns the output schema in the tree format. */
```
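A hedged usage sketch of the new `transformUpWithNewOutput` API (Catalyst-internal; the leaf swap below is illustrative): a rule returns the rewritten node together with its old-to-new attribute pairs, and parent references are fixed up automatically:

```scala
// Hedged sketch: replace a leaf with fresh expr IDs and let
// transformUpWithNewOutput rewrite the parent Project's references.
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val plan: LogicalPlan = LocalRelation(a).select(a)  // Project [a#x] over LocalRelation [a#x]

val rewritten = plan.transformUpWithNewOutput {
  case r: LocalRelation =>
    val fresh = r.output.map(_.newInstance())    // same names, new expr IDs
    LocalRelation(fresh) -> r.output.zip(fresh)  // report old -> new pairs
}
assert(rewritten.missingInput.isEmpty)  // the parent Project was rewritten
```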

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

Lines changed: 14 additions & 1 deletion
```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
 import org.apache.spark.util.Utils
@@ -120,6 +120,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     }
   }

+  /**
+   * A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
+   */
+  def resolveOperatorsUpWithNewOutput(
+      rule: PartialFunction[LogicalPlan, (LogicalPlan, Seq[(Attribute, Attribute)])])
+    : LogicalPlan = {
+    if (!analyzed) {
+      transformUpWithNewOutput(rule, skipCond = _.analyzed)
+    } else {
+      self
+    }
+  }
+
   /**
    * Recursively transforms the expressions of a tree, skipping nodes that have already
    * been analyzed.
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 16 additions & 2 deletions
```diff
@@ -21,13 +21,12 @@ import java.sql.Timestamp

 import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval

 class TypeCoercionSuite extends AnalysisTest {
   import TypeCoercionSuite._
@@ -1417,6 +1416,21 @@ class TypeCoercionSuite extends AnalysisTest {
     }
   }

+  test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
+    val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
+    val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
+    val p1 = t1.select(t1.output.head).as("p1")
+    val p2 = t2.select(t2.output.head).as("p2")
+    val union = p1.union(p2)
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
+    assert(wp1.isInstanceOf[Project])
+    // The attribute `p1.output.head` should be replaced in the root `Project`.
+    assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
+    val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
+    assert(wp2.isInstanceOf[Aggregate])
+    assert(wp2.missingInput.isEmpty)
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.
```
sql/core/src/test/resources/sql-tests/inputs/except.sql

Lines changed: 19 additions & 0 deletions
```diff
@@ -55,3 +55,22 @@ FROM t1
 WHERE t1.v >= (SELECT min(t2.v)
                FROM t2
                WHERE t2.k = t1.k);
+
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
+SELECT t.v FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;
+DROP VIEW IF EXISTS t3;
```
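For reference, a hedged sketch of what the new queries should return (t3 holds the single value 1, so the EXCEPT branches are {1} and {2} and only 1 survives; the output column name is illustrative):

```scala
// Hedged expectation for the new golden-file queries (assumes `spark`).
spark.sql("CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)")
spark.sql(
  """SELECT SUM(t.v) FROM (
    |  SELECT v FROM t3
    |  EXCEPT
    |  SELECT v + v AS v FROM t3
    |) t""".stripMargin).show()
// +------+
// |sum(v)|
// +------+
// |     1|
// +------+
```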
