Skip to content

Commit a647466

Browse files
rxingatorsmile
authored andcommitted
[SPARK-20867][SQL] Move hints from Statistics into HintInfo class
## What changes were proposed in this pull request? This is a follow-up to SPARK-20857 to move the broadcast hint from Statistics into a new HintInfo class, so we can be more flexible in adding new hints in the future. ## How was this patch tested? Updated test cases to reflect the change. Author: Reynold Xin <rxin@databricks.com> Closes #18087 from rxin/SPARK-20867.
1 parent f72ad30 commit a647466

File tree

11 files changed

+59
-48
lines changed

11 files changed

+59
-48
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ object ResolveHints {
5757
val newNode = CurrentOrigin.withOrigin(plan.origin) {
5858
plan match {
5959
case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
60-
ResolvedHint(plan, isBroadcastable = Option(true))
60+
ResolvedHint(plan, HintInfo(isBroadcastable = Option(true)))
6161
case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
62-
ResolvedHint(plan, isBroadcastable = Option(true))
62+
ResolvedHint(plan, HintInfo(isBroadcastable = Option(true)))
6363

6464
case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
6565
// Don't traverse down these nodes.
@@ -88,7 +88,7 @@ object ResolveHints {
8888
case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
8989
if (h.parameters.isEmpty) {
9090
// If there is no table alias specified, turn the entire subtree into a BroadcastHint.
91-
ResolvedHint(h.child, isBroadcastable = Option(true))
91+
ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true)))
9292
} else {
9393
// Otherwise, find within the subtree query plans that should be broadcasted.
9494
applyBroadcastHint(h.child, h.parameters.toSet)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ abstract class UnaryNode extends LogicalPlan {
347347
}
348348

349349
// Don't propagate rowCount and attributeStats, since they are not estimated here.
350-
Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
350+
Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
351351
}
352352
}
353353

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ import org.apache.spark.util.Utils
4646
* defaults to the product of children's `sizeInBytes`.
4747
* @param rowCount Estimated number of rows.
4848
* @param attributeStats Statistics for Attributes.
49-
* @param isBroadcastable If true, output is small enough to be used in a broadcast join.
49+
* @param hints Query hints.
5050
*/
5151
case class Statistics(
5252
sizeInBytes: BigInt,
5353
rowCount: Option[BigInt] = None,
5454
attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil),
55-
isBroadcastable: Boolean = false) {
55+
hints: HintInfo = HintInfo()) {
5656

5757
override def toString: String = "Statistics(" + simpleString + ")"
5858

@@ -65,14 +65,9 @@ case class Statistics(
6565
} else {
6666
""
6767
},
68-
s"isBroadcastable=$isBroadcastable"
68+
s"hints=$hints"
6969
).filter(_.nonEmpty).mkString(", ")
7070
}
71-
72-
/** Must be called when computing stats for a join operator to reset hints. */
73-
def resetHintsForJoin(): Statistics = copy(
74-
isBroadcastable = false
75-
)
7671
}
7772

7873

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
195195
val leftSize = left.stats(conf).sizeInBytes
196196
val rightSize = right.stats(conf).sizeInBytes
197197
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
198-
val isBroadcastable = left.stats(conf).isBroadcastable || right.stats(conf).isBroadcastable
199-
200-
Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable)
198+
Statistics(
199+
sizeInBytes = sizeInBytes,
200+
hints = left.stats(conf).hints.resetForJoin())
201201
}
202202
}
203203

@@ -364,7 +364,8 @@ case class Join(
364364
case _ =>
365365
// Make sure we don't propagate isBroadcastable in other joins, because
366366
// they could explode the size.
367-
super.computeStats(conf).resetHintsForJoin()
367+
val stats = super.computeStats(conf)
368+
stats.copy(hints = stats.hints.resetForJoin())
368369
}
369370

370371
if (conf.cboEnabled) {
@@ -560,7 +561,7 @@ case class Aggregate(
560561
Statistics(
561562
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
562563
rowCount = Some(1),
563-
isBroadcastable = child.stats(conf).isBroadcastable)
564+
hints = child.stats(conf).hints)
564565
} else {
565566
super.computeStats(conf)
566567
}
@@ -749,7 +750,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
749750
Statistics(
750751
sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
751752
rowCount = Some(rowCount),
752-
isBroadcastable = childStats.isBroadcastable)
753+
hints = childStats.hints)
753754
}
754755
}
755756

@@ -770,7 +771,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
770771
Statistics(
771772
sizeInBytes = 1,
772773
rowCount = Some(0),
773-
isBroadcastable = childStats.isBroadcastable)
774+
hints = childStats.hints)
774775
} else {
775776
// The output row count of LocalLimit should be the sum of row counts from each partition.
776777
// However, since the number of partitions is not available here, we just use statistics of
@@ -827,7 +828,7 @@ case class Sample(
827828
}
828829
val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
829830
// Don't propagate column stats, because we don't know the distribution after a sample operation
830-
Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
831+
Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints)
831832
}
832833

833834
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,31 @@ case class UnresolvedHint(name: String, parameters: Seq[String], child: LogicalP
3535
/**
3636
* A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]].
3737
*/
38-
case class ResolvedHint(
39-
child: LogicalPlan,
40-
isBroadcastable: Option[Boolean] = None)
38+
case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
4139
extends UnaryNode {
4240

4341
override def output: Seq[Attribute] = child.output
4442

4543
override def computeStats(conf: SQLConf): Statistics = {
4644
val stats = child.stats(conf)
47-
isBroadcastable.map(x => stats.copy(isBroadcastable = x)).getOrElse(stats)
45+
stats.copy(hints = hints)
46+
}
47+
}
48+
49+
50+
case class HintInfo(
51+
isBroadcastable: Option[Boolean] = None) {
52+
53+
/** Must be called when computing stats for a join operator to reset hints. */
54+
def resetForJoin(): HintInfo = copy(
55+
isBroadcastable = None
56+
)
57+
58+
override def toString: String = {
59+
if (productIterator.forall(_.asInstanceOf[Option[_]].isEmpty)) {
60+
"none"
61+
} else {
62+
isBroadcastable.map(x => s"isBroadcastable=$x").getOrElse("")
63+
}
4864
}
4965
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ object AggregateEstimation {
5656
sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
5757
rowCount = Some(outputRows),
5858
attributeStats = outputAttrStats,
59-
isBroadcastable = childStats.isBroadcastable))
59+
hints = childStats.hints))
6060
} else {
6161
None
6262
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ class ResolveHintsSuite extends AnalysisTest {
3636
test("case-sensitive or insensitive parameters") {
3737
checkAnalysis(
3838
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
39-
ResolvedHint(testRelation, isBroadcastable = Option(true)),
39+
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
4040
caseSensitive = false)
4141

4242
checkAnalysis(
4343
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
44-
ResolvedHint(testRelation, isBroadcastable = Option(true)),
44+
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
4545
caseSensitive = false)
4646

4747
checkAnalysis(
4848
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
49-
ResolvedHint(testRelation, isBroadcastable = Option(true)),
49+
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
5050
caseSensitive = true)
5151

5252
checkAnalysis(
@@ -58,28 +58,28 @@ class ResolveHintsSuite extends AnalysisTest {
5858
test("multiple broadcast hint aliases") {
5959
checkAnalysis(
6060
UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
61-
Join(ResolvedHint(testRelation, isBroadcastable = Option(true)),
62-
ResolvedHint(testRelation2, isBroadcastable = Option(true)), Inner, None),
61+
Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
62+
ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None),
6363
caseSensitive = false)
6464
}
6565

6666
test("do not traverse past existing broadcast hints") {
6767
checkAnalysis(
6868
UnresolvedHint("MAPJOIN", Seq("table"),
69-
ResolvedHint(table("table").where('a > 1), isBroadcastable = Option(true))),
70-
ResolvedHint(testRelation.where('a > 1), isBroadcastable = Option(true)).analyze,
69+
ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))),
70+
ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze,
7171
caseSensitive = false)
7272
}
7373

7474
test("should work for subqueries") {
7575
checkAnalysis(
7676
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
77-
ResolvedHint(testRelation, isBroadcastable = Option(true)),
77+
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
7878
caseSensitive = false)
7979

8080
checkAnalysis(
8181
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
82-
ResolvedHint(testRelation, isBroadcastable = Option(true)),
82+
ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))),
8383
caseSensitive = false)
8484

8585
// Negative case: if the alias doesn't match, don't match the original table name.
@@ -104,7 +104,7 @@ class ResolveHintsSuite extends AnalysisTest {
104104
|SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
105105
""".stripMargin
106106
),
107-
ResolvedHint(testRelation.where('a > 1).select('a), isBroadcastable = Option(true))
107+
ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true)))
108108
.select('a).analyze,
109109
caseSensitive = false)
110110
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,20 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
3737

3838
test("BroadcastHint estimation") {
3939
val filter = Filter(Literal(true), plan)
40-
val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
40+
val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4),
4141
rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
42-
val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
42+
val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4))
4343
checkStats(
4444
filter,
4545
expectedStatsCboOn = filterStatsCboOn,
4646
expectedStatsCboOff = filterStatsCboOff)
4747

48-
val broadcastHint = ResolvedHint(filter, isBroadcastable = Option(true))
48+
val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true)))
4949
checkStats(
5050
broadcastHint,
51-
expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
52-
expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
51+
expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))),
52+
expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true)))
53+
)
5354
}
5455

5556
test("limit estimation: limit < child's rowCount") {
@@ -94,15 +95,13 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
9495
sizeInBytes = 40,
9596
rowCount = Some(10),
9697
attributeStats = AttributeMap(Seq(
97-
AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
98-
isBroadcastable = false)
98+
AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))))
9999
val expectedCboStats =
100100
Statistics(
101101
sizeInBytes = 4,
102102
rowCount = Some(1),
103103
attributeStats = AttributeMap(Seq(
104-
AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
105-
isBroadcastable = false)
104+
AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))))
106105

107106
val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
108107
checkStats(

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
114114
* Matches a plan whose output should be small enough to be used in broadcast join.
115115
*/
116116
private def canBroadcast(plan: LogicalPlan): Boolean = {
117-
plan.stats(conf).isBroadcastable ||
117+
plan.stats(conf).hints.isBroadcastable.getOrElse(false) ||
118118
(plan.stats(conf).sizeInBytes >= 0 &&
119119
plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
120120
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
2929
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
3030
import org.apache.spark.sql.catalyst.expressions._
3131
import org.apache.spark.sql.catalyst.expressions.aggregate._
32-
import org.apache.spark.sql.catalyst.plans.logical.ResolvedHint
32+
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
3333
import org.apache.spark.sql.execution.SparkSqlParser
3434
import org.apache.spark.sql.expressions.UserDefinedFunction
3535
import org.apache.spark.sql.internal.SQLConf
@@ -1020,7 +1020,7 @@ object functions {
10201020
*/
10211021
def broadcast[T](df: Dataset[T]): Dataset[T] = {
10221022
Dataset[T](df.sparkSession,
1023-
ResolvedHint(df.logicalPlan, isBroadcastable = Option(true)))(df.exprEnc)
1023+
ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc)
10241024
}
10251025

10261026
/**

sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
164164
numbers.foreach { case (input, (expectedSize, expectedRows)) =>
165165
val stats = Statistics(sizeInBytes = input, rowCount = Some(input))
166166
val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," +
167-
s" isBroadcastable=${stats.isBroadcastable}"
167+
s" hints=none"
168168
assert(stats.simpleString == expectedString)
169169
}
170170
}

0 commit comments

Comments
 (0)