[SPARK-27816][SQL] make TreeNode tag type safe #24687

@@ -1083,7 +1083,7 @@ case class OneRowRelation() extends LeafNode {
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
   override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
     val newCopy = OneRowRelation()
-    newCopy.tags ++= this.tags
+    newCopy.copyTagsFrom(this)
     newCopy
   }
 }
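
The behaviour this hunk preserves is that tags set on a `OneRowRelation` survive its overridden `makeCopy`. A minimal sketch (not part of the diff, using the `setTagValue`/`getTagValue` API introduced in the `TreeNode` changes below; the tag name "note" is made up for illustration):

```scala
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val noteTag = TreeNodeTag[String]("note")   // hypothetical tag, for illustration only
val rel = OneRowRelation()
rel.setTagValue(noteTag, "kept across copies")

// The override delegates to copyTagsFrom, so the fresh copy carries the same tag value.
val copied = rel.makeCopy(Array.empty)
assert(copied.getTagValue(noteTag) == Some("kept across copies"))
```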

@@ -74,9 +74,8 @@ object CurrentOrigin {
   }
 }
 
-// The name of the tree node tag. This is preferred over using string directly, as we can easily
-// find all the defined tags.
-case class TreeNodeTagName(name: String)
+// A tag of a `TreeNode`, which defines name and type
+case class TreeNodeTag[T](name: String)
 
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
@@ -89,7 +88,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * A mutable map for holding auxiliary information of this tree node. It will be carried over
    * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
    */
-  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+  private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
+
+  protected def copyTagsFrom(other: BaseType): Unit = {
+    tags ++= other.tags
+  }
+
+  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
+    tags(tag) = value
+  }
+
+  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
+    tags.get(tag).map(_.asInstanceOf[T])
+  }
 
   /**
    * Returns a Seq of the children of this node.
@@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     try {
       CurrentOrigin.withOrigin(origin) {
         val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
-        res.tags ++= this.tags
+        res.copyTagsFrom(this)
         res
       }
     } catch {
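
For illustration, a minimal sketch (not part of the diff) of how the type-safe tag API above is meant to be used; the tag name "maxIterations" and the `Literal` node are arbitrary choices for the example:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

// The type parameter of the tag fixes the value type at the declaration site.
val maxIterTag = TreeNodeTag[Int]("maxIterations")

val expr = Literal(1)
expr.setTagValue(maxIterTag, 10)                     // only an Int compiles here
val v: Option[Int] = expr.getTagValue(maxIterTag)    // Some(10), no asInstanceOf at the call site
// expr.setTagValue(maxIterTag, "ten")               // rejected at compile time
```

With the old string-keyed `tags` map, callers had to cast the retrieved `Any` themselves; the cast now lives in one place, inside `getTagValue`.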

@@ -622,31 +622,33 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("tags will be carried over after copy & transform") {
+    val tag = TreeNodeTag[String]("test")
+
     withClue("makeCopy") {
       val node = Dummy(None)
-      node.tags += TreeNodeTagName("test") -> "a"
+      node.setTagValue(tag, "a")
       val copied = node.makeCopy(Array(Some(Literal(1))))
-      assert(copied.tags(TreeNodeTagName("test")) == "a")
+      assert(copied.getTagValue(tag) == Some("a"))
     }
 
     def checkTransform(
         sameTypeTransform: Expression => Expression,
         differentTypeTransform: Expression => Expression): Unit = {
       val child = Dummy(None)
-      child.tags += TreeNodeTagName("test") -> "child"
+      child.setTagValue(tag, "child")
       val node = Dummy(Some(child))
-      node.tags += TreeNodeTagName("test") -> "parent"
+      node.setTagValue(tag, "parent")
 
       val transformed = sameTypeTransform(node)
       // Both the child and parent keep the tags
-      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
-      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+      assert(transformed.getTagValue(tag) == Some("parent"))
+      assert(transformed.children.head.getTagValue(tag) == Some("child"))
 
       val transformed2 = differentTypeTransform(node)
       // Both the child and parent keep the tags, even if we transform the node to a new one of
       // different type.
-      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
-      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+      assert(transformed2.getTagValue(tag) == Some("parent"))
+      assert(transformed2.children.head.getTagValue(tag) == Some("child"))
     }
 
     withClue("transformDown") {

@@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
   // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
   // when converting a logical plan to a physical plan.
-  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+  val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
 }
 
 /**

@@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         case ReturnAnswer(rootPlan) => rootPlan
         case _ => plan
       }
-      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
       p
     }
   }
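
For context, a small sketch (not from the PR) of how the link set above can be read back from a physical plan; `linkedLogicalPlan` is a hypothetical helper name:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Returns the logical plan the planner linked to this physical node, if any.
def linkedLogicalPlan(plan: SparkPlan): Option[LogicalPlan] =
  plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
```

Nodes created after planning (exchanges, subquery wrappers) have no such link, which is what `LogicalPlanTagInSparkPlanSuite` below asserts via `getTagValue(...).isEmpty`.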

@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.TPCDSQuerySuite
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
       // The exchange related nodes are created after the planning, they don't have corresponding
       // logical plan.
       case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       // The subquery exec nodes are just wrappers of the actual nodes, they don't have
       // corresponding logical plan.
       case _: SubqueryExec | _: ReusedSubqueryExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       case _ if isScanPlanTree(plan) =>
         // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
@@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
   }
 
   private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
-    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
-      node.getClass.getSimpleName + " does not have a logical plan link")
-    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+    node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
+      fail(node.getClass.getSimpleName + " does not have a logical plan link")
+    }
   }
 
   private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {