
[SPARK-28148][SQL] Repartition after join is not optimized away #27096

Closed · wants to merge 5 commits
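For context, a minimal reproduction of the issue named in the title, reconstructed from the regression test added below (the local SparkSession setup is an assumption for illustration, not part of the PR):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val df1 = spark.range(0, 5000000, 1, 5)
val df2 = spark.range(0, 10000000, 1, 5)

// The sort-merge join already leaves its output hash-partitioned and sorted on "id",
// so the user's repartition/sortWithinPartitions on the same key should be no-ops.
// Before this change they survived as an extra shuffle and sort in the executed plan.
val plan = df1.join(df2, Seq("id"), "left")
  .repartition(df1("id"))
  .sortWithinPartitions(df1("id"))
  .queryExecution.executedPlan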
QueryExecution.scala

@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
 import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, PruneShuffleAndSort, ReuseExchange}
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode

@@ -285,6 +285,7 @@ object QueryExecution {
       PlanDynamicPruningFilters(sparkSession),
       PlanSubqueries(sparkSession),
       EnsureRequirements(sparkSession.sessionState.conf),
+      PruneShuffleAndSort(),
       ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
         sparkSession.sessionState.columnarRules),
       CollapseCodegenStages(sparkSession.sessionState.conf),
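A simplified sketch of how these preparation rules compose (this helper is an illustration, not Spark's actual QueryExecution code): each Rule[SparkPlan] rewrites the physical plan in sequence, which is why PruneShuffleAndSort is registered immediately after EnsureRequirements: it must see the exchanges and sorts that EnsureRequirements has just inserted.

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical helper: fold the plan through the ordered preparation rules.
def applyPreparations(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan =
  rules.foldLeft(plan) { (p, rule) => rule.apply(p) }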
EnsureRequirements.scala

@@ -216,12 +216,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
-      child.outputPartitioning match {
-        case lower: HashPartitioning if upper.semanticEquals(lower) => child
-        case _ => operator
-      }
     case operator: SparkPlan =>
       ensureDistributionAndOrdering(reorderJoinPredicates(operator))
   }
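Note that the pruning case is moved out of EnsureRequirements rather than deleted: as a standalone rule that runs after the whole EnsureRequirements pass (see its registration in QueryExecution above), it is also extended to handle a child whose output partitioning is a PartitioningCollection (the typical shape of a join's output) and to drop redundant non-global SortExec nodes, as the new file below shows.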
PruneShuffleAndSort.scala (new file)

@@ -0,0 +1,52 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SortExec, SparkPlan}

/**
 * Removes unnecessary shuffles and sorts after new ones are introduced by [[Rule]]s for
 * [[SparkPlan]]s, such as [[EnsureRequirements]].
 */
case class PruneShuffleAndSort() extends Rule[SparkPlan] {

  override def apply(plan: SparkPlan): SparkPlan = {
    plan.transformUp {
      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
        child.outputPartitioning match {
          case lower: HashPartitioning if upper.semanticEquals(lower) => child
          case _ @ PartitioningCollection(partitionings) =>
            if (partitionings.exists {
              case lower: HashPartitioning => upper.semanticEquals(lower)
              case _ => false
            }) {
              child
            } else {
              operator
            }
          case _ => operator
        }
      case SortExec(upper, false, child, _)
          if SortOrder.orderingSatisfies(child.outputOrdering, upper) => child
      case subPlan: SparkPlan => subPlan
    }
  }
}

Review thread on the line `case lower: HashPartitioning if upper.semanticEquals(lower) => child`:

Contributor:

Why doesn't this apply to RangePartitioning? Is it just that we assume repartition() would only do HashPartitioning?
And what happens if someone does df1.join(df2, Seq("id"), "left").repartition(100, df1("some_other_column")).repartition(20, df1("id")) or df1.join(df2, Seq("id"), "left").sortWithinPartitions(df1("some_other_column")).sortWithinPartitions(df1("id"))? We should be able to optimize those out too, right? It would be nice to make this rule more general so it covers a wider range of cases.

Contributor Author:

Thanks for providing feedback.
Let me take a look at your specific examples and think a little more about it.

Contributor Author (@bmarcott, Feb 24, 2020):

@maryannxue
This PR focuses on removing unnecessary sorting and shuffling after a join, which potentially includes its own ShuffleExchangeExec with HashPartitioning. Both cases you mentioned are already optimized properly: the shuffle on "some_other_column" is removed, and all the sortWithinPartitions are removed (due to earlier optimizations in the logical plan, plus the optimizations introduced here).

I wouldn't mind generalizing to all Partitioning types of the ShuffleExchangeExec, but I'm not sure how to compare two arbitrary partitionings for equality. You can see the special case for HashPartitioning in this PR.
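To make the HashPartitioning special case concrete: HashPartitioning is itself a Catalyst expression, so semanticEquals compares the canonicalized partitioning (the partitioning expressions and the partition count), ignoring cosmetic differences. A small illustrative sketch, not code from the PR:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.LongType

val id = AttributeReference("id", LongType)()

// Same expressions, same partition count: the upper shuffle is redundant.
val upper = HashPartitioning(Seq(id), numPartitions = 20)
val lower = HashPartitioning(Seq(id), numPartitions = 20)
assert(upper.semanticEquals(lower))

// Different partition count: not semantically equal, so the shuffle is kept,
// matching the test's "user-defined numPartitions shouldn't be eliminated".
val fewer = HashPartitioning(Seq(id), numPartitions = 10)
assert(!upper.semanticEquals(fewer))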
PlannerSuite.scala

@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, PruneShuffleAndSort, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf

@@ -433,7 +433,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     val inputPlan = ShuffleExchangeExec(
       partitioning,
       DummySparkPlan(outputPartitioning = partitioning))
-    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
+    val outputPlan = PruneShuffleAndSort().apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) {
       fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")

@@ -727,6 +727,48 @@
     }
   }

+  test("SPARK-28148: repartition after join is not optimized away") {
+
+    def numSorts(plan: SparkPlan): Int = {
+      plan.collect { case s: SortExec => s }.length
+    }
+
+    def numShuffles(plan: SparkPlan): Int = {
+      plan.collect { case s: ShuffleExchangeExec => s }.length
+    }
+
+    val df1 = spark.range(0, 5000000, 1, 5)
+    val df2 = spark.range(0, 10000000, 1, 5)
+
+    val outputPlan0 = df1.join(df2, Seq("id"), "left")
+      .repartition(20, df1("id")).queryExecution.executedPlan
+    assert(numSorts(outputPlan0) == 2)
+    assert(numShuffles(outputPlan0) == 3, "user-defined numPartitions shouldn't be eliminated")
+
+    // non-global sort order and partitioning should be reusable after a left join
+    val outputPlan1 = df1.join(df2, Seq("id"), "left")
+      .repartition(df1("id"))
+      .sortWithinPartitions(df1("id"))
+      .queryExecution.executedPlan
+    assert(numSorts(outputPlan1) == 2)
+    assert(numShuffles(outputPlan1) == 2)
+
+    // non-global sort order and partitioning should be reusable after an inner join
+    val outputPlan2 = df1.join(df2, Seq("id"))
+      .repartition(df1("id"))
+      .sortWithinPartitions(df1("id"))
+      .queryExecution.executedPlan
+    assert(numSorts(outputPlan2) == 2)
+    assert(numShuffles(outputPlan2) == 2)
+
+    // a global sort should not be removed
+    val outputPlan3 = df1.join(df2, Seq("id"))
+      .orderBy(df1("id"))
+      .queryExecution.executedPlan
+    assert(numSorts(outputPlan3) == 3)
+    assert(numShuffles(outputPlan3) == 3)
+  }
+
   test("SPARK-24500: create union with stream of children") {
     val df = Union(Stream(
       Range(1, 1, 1, 1),
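A note on the expected counts (inferred from the plans; not stated in the PR): the sort-merge join needs one shuffle and one sort per input side, giving the baseline of 2 sorts and 2 shuffles. In outputPlan0 the repartition(20, ...) keeps a third shuffle because its user-defined partition count differs from the join's. In outputPlan1 and outputPlan2 the repartition and sortWithinPartitions on the join key are pruned entirely. In outputPlan3 the global orderBy legitimately adds a third shuffle (a range repartitioning) and a third sort.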