Skip to content

Commit dafad5c

Browse files
committed
always carry over the tags in transform
1 parent 9f377d7 commit dafad5c

File tree

1 file changed

+5
-13
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees

1 file changed

+5
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
8787

8888
/**
8989
* A mutable map for holding auxiliary information of this tree node. It will be carried over
90-
* when this node is copied via `makeCopy`. The tags will be kept after transforming, if
91-
* the node is transformed to the same type. Otherwise, tags will be dropped.
90+
* when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
9291
*/
9392
val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
9493

@@ -273,12 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
273272
if (this fastEquals afterRule) {
274273
mapChildren(_.transformDown(rule))
275274
} else {
276-
// If the transform function replaces this node with a new one of the same type, carry over
277-
// the tags.
278-
if (afterRule.getClass == this.getClass) {
279-
afterRule.tags ++= this.tags
280-
}
281-
275+
// If the transform function replaces this node with a new one, carry over the tags.
276+
afterRule.tags ++= this.tags
282277
afterRule.mapChildren(_.transformDown(rule))
283278
}
284279
}
@@ -301,11 +296,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
301296
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
302297
}
303298
}
304-
// If the transform function replaces this node with a new one of the same type, carry over
305-
// the tags.
306-
if (newNode.getClass == this.getClass) {
307-
newNode.tags ++= this.tags
308-
}
299+
// If the transform function replaces this node with a new one, carry over the tags.
300+
newNode.tags ++= this.tags
309301
newNode
310302
}
311303

0 commit comments

Comments
 (0)