Skip to content

Commit 83483b1

Browse files
author
mingbo.pb
committed
Go back to using the propagate template method in QueryPlanner.plan for extension strategy compatibility
1 parent b3cd0b0 commit 83483b1

File tree

10 files changed

+32
-29
lines changed

10 files changed

+32
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
6060
// Obviously a lot to do here still...
6161

6262
// Collect physical plan candidates.
63-
val candidates = strategies.iterator.flatMap(_(plan))
63+
val candidates = strategies.iterator.flatMap({
64+
_(plan).map(propagate(_, plan))
65+
})
6466

6567
// The candidates may contain placeholders marked as [[planLater]],
6668
// so try to replace them by their child plans.
@@ -102,4 +104,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
102104

103105
/** Prunes bad plans to prevent combinatorial explosion. */
104106
protected def prunePlans(plans: Iterator[PhysicalPlan]): Iterator[PhysicalPlan]
107+
108+
/** Propagate logicalPlan properties to PhysicalPlan */
109+
protected def propagate(plan: PhysicalPlan, logicalPlan: LogicalPlan): PhysicalPlan
105110
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ class SparkPlanner(
6565
plans
6666
}
6767

68+
override protected def propagate(plan: SparkPlan, logicalPlan: LogicalPlan): SparkPlan = {
69+
plan.withStats(logicalPlan.stats)
70+
}
71+
6872
/**
6973
* Used to build table scan operators where complex projection and filtering are done using
7074
* separate physical operators. This function returns the given scan operator with Project and

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ import org.apache.spark.sql.types.StructType
4949
abstract class SparkStrategy extends GenericStrategy[SparkPlan] {
5050

5151
override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan)
52-
53-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
54-
doApply(plan).map(sparkPlan => sparkPlan.withStats(plan.stats))
55-
}
56-
57-
protected def doApply(plan: LogicalPlan): Seq[SparkPlan]
5852
}
5953

6054
case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
@@ -73,7 +67,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
7367
* Plans special cases of limit operators.
7468
*/
7569
object SpecialLimits extends Strategy {
76-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
70+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
7771
case ReturnAnswer(rootPlan) => rootPlan match {
7872
case Limit(IntegerLiteral(limit), Sort(order, true, child))
7973
if limit < conf.topKSortFallbackThreshold =>
@@ -215,7 +209,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
215209
hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
216210
}
217211

218-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
212+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
219213

220214
// If it is an equi-join, we first look at the join hints w.r.t. the following order:
221215
// 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides
@@ -389,7 +383,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
389383
* on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
390384
*/
391385
object StatefulAggregationStrategy extends Strategy {
392-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
386+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
393387
case _ if !plan.isStreaming => Nil
394388

395389
case EventTimeWatermark(columnName, delay, child) =>
@@ -429,7 +423,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
429423
* Used to plan the streaming deduplicate operator.
430424
*/
431425
object StreamingDeduplicationStrategy extends Strategy {
432-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
426+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
433427
case Deduplicate(keys, child) if child.isStreaming =>
434428
StreamingDeduplicateExec(keys, planLater(child)) :: Nil
435429

@@ -446,7 +440,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
446440
* Limit is unsupported for streams in Update mode.
447441
*/
448442
case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy {
449-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
443+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
450444
case ReturnAnswer(rootPlan) => rootPlan match {
451445
case Limit(IntegerLiteral(limit), child)
452446
if plan.isStreaming && outputMode == InternalOutputModes.Append =>
@@ -461,7 +455,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
461455
}
462456

463457
object StreamingJoinStrategy extends Strategy {
464-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = {
458+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
465459
plan match {
466460
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
467461
if left.isStreaming && right.isStreaming =>
@@ -482,7 +476,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
482476
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
483477
*/
484478
object Aggregation extends Strategy {
485-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
479+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
486480
case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
487481
if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) =>
488482
val aggregateExpressions = aggExpressions.map(expr =>
@@ -544,7 +538,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
544538
}
545539

546540
object Window extends Strategy {
547-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
541+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
548542
case PhysicalWindow(
549543
WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) =>
550544
execution.window.WindowExec(
@@ -562,7 +556,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
562556
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
563557

564558
object InMemoryScans extends Strategy {
565-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
559+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
566560
case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
567561
pruneFilterProject(
568562
projectList,
@@ -580,7 +574,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
580574
* be replaced with the real relation using the `Source` in `StreamExecution`.
581575
*/
582576
object StreamingRelationStrategy extends Strategy {
583-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
577+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
584578
case s: StreamingRelation =>
585579
StreamingRelationExec(s.sourceName, s.output) :: Nil
586580
case s: StreamingExecutionRelation =>
@@ -596,7 +590,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
596590
* in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
597591
*/
598592
object FlatMapGroupsWithStateStrategy extends Strategy {
599-
override def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
593+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
600594
case FlatMapGroupsWithState(
601595
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
602596
timeout, child) =>
@@ -614,7 +608,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
614608
* Strategy to convert EvalPython logical operator to physical operator.
615609
*/
616610
object PythonEvals extends Strategy {
617-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
611+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
618612
case ArrowEvalPython(udfs, output, child) =>
619613
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
620614
case BatchEvalPython(udfs, output, child) =>
@@ -625,7 +619,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
625619
}
626620

627621
object BasicOperators extends Strategy {
628-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
622+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
629623
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
630624
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
631625

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
261261
case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport {
262262
import DataSourceStrategy._
263263

264-
override protected def doApply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
264+
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
265265
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
266266
pruneFilterProjectRaw(
267267
l,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ object FileSourceStrategy extends Strategy with Logging {
136136
}
137137
}
138138

139-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
139+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
140140
case PhysicalOperation(projects, filters,
141141
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
142142
// Filters on this relation fall into four categories based on where we can use them to avoid

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
102102

103103
import DataSourceV2Implicits._
104104

105-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
105+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
106106
case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
107107
val scanBuilder = relation.newScanBuilder()
108108

sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
3939
}
4040

4141
object TestStrategy extends Strategy {
42-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
42+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
4343
case Project(Seq(attr), _) if attr.name == "a" =>
4444
FastOperator(attr.toAttribute :: Nil) :: Nil
4545
case _ => Nil

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) {
218218
}
219219

220220
case class MySparkStrategy(spark: SparkSession) extends SparkStrategy {
221-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
221+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
222222
}
223223

224224
case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
@@ -272,7 +272,7 @@ case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
272272
}
273273

274274
case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
275-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
275+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
276276
}
277277

278278
object MyExtensions2 {

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class SparkPlannerSuite extends SharedSQLContext {
3333

3434
var planned = 0
3535
object TestStrategy extends Strategy {
36-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
36+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
3737
case ReturnAnswer(child) =>
3838
planned += 1
3939
planLater(child) :: planLater(NeverPlanned) :: Nil

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ private[hive] trait HiveStrategies {
223223
val sparkSession: SparkSession
224224

225225
object Scripts extends Strategy {
226-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
226+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
227227
case ScriptTransformation(input, script, output, child, ioschema) =>
228228
val hiveIoSchema = HiveScriptIOSchema(ioschema)
229229
ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
@@ -236,7 +236,7 @@ private[hive] trait HiveStrategies {
236236
* applied.
237237
*/
238238
object HiveTableScans extends Strategy {
239-
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
239+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
240240
case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) =>
241241
// Filter out all predicates that only deal with partition keys, these are given to the
242242
// hive table scan operator to be used for partition pruning.

0 commit comments

Comments
 (0)