
Commit 5cce482

move the plan rewrite methods to QueryPlan

1 parent a6114d8 commit 5cce482

5 files changed (+132 additions, −162 deletions)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 122 deletions
@@ -123,127 +123,6 @@ object AnalysisContext {
   }
 }
 
-object Analyzer {
-
-  /**
-   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
-   * This method also updates all the related references in the `plan` accordingly.
-   *
-   * @param plan to rewrite
-   * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
-   * @return a rewritten plan and updated references related to a root node of
-   *         the given `plan` for rewriting it.
-   */
-  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
-    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-    if (plan.resolved) {
-      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-      val newChildren = plan.children.map { child =>
-        // If not, we'd rewrite child plan recursively until we find the
-        // conflict node or reach the leaf node.
-        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
-        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-          // `attrMapping` is not only used to replace the attributes of the current `plan`,
-          // but also to be propagated to the parent plans of the current `plan`. Therefore,
-          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-          // used by those parent plans).
-          (plan.outputSet ++ plan.references).contains(oldAttr)
-        }
-        newChild
-      }
-
-      val newPlan = if (rewritePlanMap.contains(plan)) {
-        rewritePlanMap(plan).withNewChildren(newChildren)
-      } else {
-        plan.withNewChildren(newChildren)
-      }
-
-      assert(!attrMapping.groupBy(_._1.exprId)
-        .exists(_._2.map(_._2.exprId).distinct.length > 1),
-        "Found duplicate rewrite attributes")
-
-      val attributeRewrites = AttributeMap(attrMapping)
-      // Using attrMapping from the children plans to rewrite their parent node.
-      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-      val p = newPlan.transformExpressions {
-        case a: Attribute =>
-          updateAttr(a, attributeRewrites)
-        case s: SubqueryExpression =>
-          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
-      }
-      attrMapping ++= plan.output.zip(p.output)
-        .filter { case (a1, a2) => a1.exprId != a2.exprId }
-      p -> attrMapping
-    } else {
-      // Just passes through unresolved nodes
-      plan.mapChildren {
-        rewritePlan(_, rewritePlanMap)._1
-      } -> Nil
-    }
-  }
-
-  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-    val exprId = attrMap.getOrElse(attr, attr).exprId
-    attr.withExprId(exprId)
-  }
-
-  /**
-   * The outer plan may have old references and the function below updates the
-   * outer references to refer to the new attributes.
-   *
-   * For example (SQL):
-   * {{{
-   *   SELECT * FROM t1
-   *   INTERSECT
-   *   SELECT * FROM t1
-   *   WHERE EXISTS (SELECT 1
-   *                 FROM t2
-   *                 WHERE t1.c1 = t2.c1)
-   * }}}
-   * Plan before resolveReference rule.
-   *    'Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- 'Project [*]
-   *       +- Filter exists#257 [c1#245]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#245) = c1#251)
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#245,c2#246] parquet
-   * Plan after the resolveReference rule.
-   *    Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- Project [c1#259, c2#260]
-   *       +- Filter exists#257 [c1#259]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
-   */
-  private def updateOuterReferencesInSubquery(
-      plan: LogicalPlan,
-      attrMap: AttributeMap[Attribute]): LogicalPlan = {
-    AnalysisHelper.allowInvokingTransformsInAnalyzer {
-      plan transformDown { case currentFragment =>
-        currentFragment transformExpressions {
-          case OuterReference(a: Attribute) =>
-            OuterReference(updateAttr(a, attrMap))
-          case s: SubqueryExpression =>
-            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
-        }
-      }
-    }
-  }
-}
-
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
  * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].

@@ -1376,7 +1255,14 @@ class Analyzer(
     if (conflictPlans.isEmpty) {
       right
     } else {
-      Analyzer.rewritePlan(right, conflictPlans.toMap)._1
+      val planMapping = conflictPlans.toMap
+      right.transformUpWithNewOutput {
+        case oldPlan =>
+          val newPlanOpt = planMapping.get(oldPlan)
+          newPlanOpt.map { newPlan =>
+            newPlan -> oldPlan.output.zip(newPlan.output)
+          }.getOrElse(oldPlan -> Nil)
+      }
     }
   }
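
Note: the replacement call site above shows the general contract of the new API. Instead of post-processing the result with Analyzer.rewritePlan, the rule itself returns the replacement node paired with the old-to-new attribute pairs, and transformUpWithNewOutput propagates that mapping to every ancestor. A minimal sketch of the same pattern factored into a standalone helper — the name rewriteWithMapping is ours, not part of this commit, and it assumes a build containing this change:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper mirroring the new dedup code path: nodes found in
// `planMapping` are swapped for their replacements, and the attribute pairs
// returned with each replacement let the framework rewrite references in
// all parent operators.
def rewriteWithMapping(
    plan: LogicalPlan,
    planMapping: Map[LogicalPlan, LogicalPlan]): LogicalPlan = {
  plan.transformUpWithNewOutput {
    case oldPlan =>
      planMapping.get(oldPlan)
        .map(newPlan => newPlan -> oldPlan.output.zip(newPlan.output))
        .getOrElse(oldPlan -> Nil)
  }
}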

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 24 additions & 35 deletions
@@ -329,50 +329,43 @@ object TypeCoercion {
   object WidenSetOperationTypes extends TypeCoercionRule {
 
     override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
-      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
-      val newPlan = plan resolveOperatorsUp {
+      plan resolveOperatorsUpWithNewOutput {
         case s @ Except(left, right, isAll) if s.childrenResolved &&
             left.output.length == right.output.length && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            Except(newChildren.head._1, newChildren.last._1, isAll)
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Except(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
         case s @ Intersect(left, right, isAll) if s.childrenResolved &&
             left.output.length == right.output.length && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            Intersect(newChildren.head._1, newChildren.last._1, isAll)
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
         case s: Union if s.childrenResolved && !s.byName &&
             s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(s.children)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            s.copy(children = newChildren.map(_._1))
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            val attrMapping = s.children.head.output.zip(newChildren.head.output)
+            s.copy(children = newChildren) -> attrMapping
           }
       }
-
-      if (rewritePlanMap.nonEmpty) {
-        assert(!plan.fastEquals(newPlan))
-        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
-      } else {
-        plan
-      }
     }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
-      : Seq[(LogicalPlan, LogicalPlan)] = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute

@@ -408,16 +401,12 @@ object TypeCoercion {
     }
 
     /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
-      : (LogicalPlan, LogicalPlan) = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
       val casted = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt =>
-          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
-          alias -> alias.newInstance()
-        case (e, _) =>
-          e -> e
-      }.unzip
-      Project(casted._1, plan) -> Project(casted._2, plan)
+        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+        case (e, _) => e
+      }
+      Project(casted, plan)
     }
   }
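
Note: the widenTypes simplification is the behavioral core of this file's change. The old version built two parallel Projects so that one of them could reuse the original expr IDs; the new one always allocates fresh expr IDs via Alias(...)() and leaves the reference fix-up to transformUpWithNewOutput. A small sketch of what the widening now produces, using only classes that already appear in this diff (assumes a spark-shell with this commit's Catalyst on the classpath):

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.DecimalType

val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val narrow = t1.output.head

// Alias(...)() allocates a fresh exprId; the removed code passed
// `exprId = e.exprId` here to keep the old one alive.
val widened = Alias(Cast(narrow, DecimalType(11, 0)), narrow.name)()
val proj = Project(Seq(widened), t1)

// Any parent still referring to `narrow` is now stale, which is why the rule
// returns e.g. left.output.zip(newChildren.head.output) as its mapping.
assert(proj.output.head.exprId != narrow.exprId)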

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 81 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}

@@ -168,6 +170,85 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     }.toSeq
   }
 
+  /**
+   * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node
+   * with a new one that has different output expr IDs, by updating the attribute references in
+   * the parent nodes accordingly.
+   *
+   * @param rule the function to transform plan nodes, returning new nodes along with a mapping
+   *             from old attributes to new attributes.
+   * @param skipCond a condition to indicate whether transforming a plan node can be skipped.
+   */
+  def transformUpWithNewOutput(
+      rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])],
+      skipCond: PlanType => Boolean = _ => false): PlanType = {
+    def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
+      if (skipCond(plan)) {
+        plan -> Nil
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        var newPlan = plan.mapChildren { child =>
+          val (newChild, childAttrMapping) = rewrite(child)
+          attrMapping ++= childAttrMapping
+          newChild
+        }
+
+        val attrMappingForCurrentPlan = attrMapping.filter {
+          // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+          // current `plan`, so the `oldAttr` must be part of `plan.references`.
+          case (oldAttr, _) => plan.references.contains(oldAttr)
+        }
+
+        val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
+        }
+        newPlan = planAfterRule
+
+        if (attrMappingForCurrentPlan.nonEmpty) {
+          assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+
+          val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+          // Using attrMapping from the children plans to rewrite their parent node.
+          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+          newPlan = newPlan.transformExpressions {
+            case a: AttributeReference =>
+              updateAttr(a, attributeRewrites)
+            case pe: PlanExpression[PlanType] =>
+              pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
+          }
+        }
+
+        attrMapping ++= newAttrMapping.filter {
+          case (a1, a2) => a1.exprId != a2.exprId
+        }
+        newPlan -> attrMapping
+      }
+    }
+    rewrite(this)._1
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: PlanType,
+      attrMap: AttributeMap[Attribute]): PlanType = {
+    plan.transformDown { case currentFragment =>
+      currentFragment.transformExpressions {
+        case OuterReference(a: AttributeReference) =>
+          OuterReference(updateAttr(a, attrMap))
+        case pe: PlanExpression[PlanType] =>
+          pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
+      }
+    }
+  }
+
   lazy val schema: StructType = StructType.fromAttributes(output)
 
   /** Returns the output schema in the tree format. */
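
Note: since the method is generic in PlanType, the mechanics are easier to see in isolation. Below is a self-contained toy model (our illustration, not Spark code) of the same algorithm: rewrite the tree bottom-up, collect old-to-new output ids from rewritten children, and patch each node's references using only mappings that came from its own subtree:

import scala.collection.mutable

// Toy plan node: `out` is the ids it produces, `refs` the ids it consumes.
case class Node(name: String, out: Seq[Int], refs: Seq[Int], children: Seq[Node])

def transformUpWithNewOutput(
    plan: Node)(rule: PartialFunction[Node, (Node, Seq[(Int, Int)])]): Node = {
  def rewrite(n: Node): (Node, Seq[(Int, Int)]) = {
    val mapping = mutable.ArrayBuffer[(Int, Int)]()
    val withNewChildren = n.copy(children = n.children.map { c =>
      val (newChild, childMapping) = rewrite(c)
      mapping ++= childMapping
      newChild
    })
    val (afterRule, newMapping) =
      rule.applyOrElse(withNewChildren, (x: Node) => x -> (Nil: Seq[(Int, Int)]))
    // Patch this node's references with the mapping gathered from its children,
    // mirroring the AttributeMap-based rewrite above.
    val rewrites = mapping.toMap
    val patched = afterRule.copy(refs = afterRule.refs.map(r => rewrites.getOrElse(r, r)))
    patched -> (mapping ++ newMapping.filter { case (a, b) => a != b }).toSeq
  }
  rewrite(plan)._1
}

// The rule replaces the leaf relation and reports its new output id; the
// parent projection's reference is updated automatically.
val leaf = Node("relation", out = Seq(1), refs = Nil, children = Nil)
val proj = Node("project", out = Seq(1), refs = Seq(1), children = Seq(leaf))
val rewritten = transformUpWithNewOutput(proj) {
  case n if n.name == "relation" => n.copy(out = Seq(2)) -> Seq(1 -> 2)
}
assert(rewritten.refs == Seq(2))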

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

Lines changed: 14 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
 import org.apache.spark.util.Utils

@@ -120,6 +120,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     }
   }
 
+  /**
+   * A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
+   */
+  def resolveOperatorsUpWithNewOutput(
+      rule: PartialFunction[LogicalPlan, (LogicalPlan, Seq[(Attribute, Attribute)])])
+    : LogicalPlan = {
+    if (!analyzed) {
+      transformUpWithNewOutput(rule, skipCond = _.analyzed)
+    } else {
+      self
+    }
+  }
+
   /**
    * Recursively transforms the expressions of a tree, skipping nodes that have already
    * been analyzed.
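
Note: the wrapper's only additions over transformUpWithNewOutput are the `analyzed` guard and `skipCond = _.analyzed`, so already-analyzed subtrees are never touched. A hedged sketch of the expected rule shape — the widenUnions wrapper and its body are ours, modeled on the WidenSetOperationTypes change above, with the actual widening logic elided:

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Union}

// Hypothetical usage: each case returns the replacement node plus the
// old-to-new attribute pairs; identity pairs are filtered out downstream.
def widenUnions(plan: LogicalPlan): LogicalPlan =
  plan.resolveOperatorsUpWithNewOutput {
    case u: Union if u.childrenResolved && !u.resolved =>
      val newChildren: Seq[LogicalPlan] = u.children  // widening elided
      val attrMapping: Seq[(Attribute, Attribute)] =
        u.children.head.output.zip(newChildren.head.output)
      u.copy(children = newChildren) -> attrMapping
  }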

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 5 additions & 4 deletions
@@ -1419,12 +1419,13 @@ class TypeCoercionSuite extends AnalysisTest {
   test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
     val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
     val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
-    val p1 = t1.select(t1.output.head)
-    val p2 = t2.select(t2.output.head)
+    val p1 = t1.select(t1.output.head).as("p1")
+    val p2 = t2.select(t2.output.head).as("p2")
     val union = p1.union(p2)
-    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
     assert(wp1.isInstanceOf[Project])
-    assert(wp1.missingInput.isEmpty)
+    // The attribute `p1.output.head` should be replaced in the root `Project`.
+    assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
     val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
     assert(wp2.isInstanceOf[Aggregate])
     assert(wp2.missingInput.isEmpty)
