
Commit 61bd1c9

Guo Chenzhao authored and plusplusjiajia committed
Reduce shuffles for successive full outer join (apache#63)
* Modify existing Partitioning & Distribution to reduce shuffles for full outer join
* Refactor and test
1 parent 471fd59 · commit 61bd1c9
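For context, the plan shape this change targets can be sketched as follows (a hypothetical example, assuming an active SparkSession `spark`; names are illustrative). Previously the first full outer join reported UnknownPartitioning, so the planner re-shuffled its output before the second join; with this change the chained plan needs one shuffle per base table (3 total, per the new PlannerSuite test) instead of 4:

```scala
// Two successive full outer joins on the same key (hypothetical sketch).
val a = spark.range(100).selectExpr("id AS k1")
val b = spark.range(100).selectExpr("id AS k2")
val c = spark.range(100).selectExpr("id AS k3")

val chained = a
  .join(b, a("k1") === b("k2"), "full_outer") // shuffles a and b on the key
  .join(c, a("k1") === c("k3"), "full_outer") // now reuses the first join's partitioning

chained.explain() // count the ShuffleExchange nodes in the physical plan
```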

File tree

5 files changed: +78 −21 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 21 additions & 10 deletions
```diff
@@ -92,16 +92,22 @@ case class ClusteredDistribution(
 }
 
 /**
- * Represents data where tuples have been clustered according to the hash of the given
- * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * If exceptNull == false: Represents data where tuples have been clustered according to the hash of
+ * the given `expressions`.
+ * If exceptNull == true: Represents data where tuples have been clustered according to the hash of
+ * the given `expressions` except NULL, meaning NULL rows may be distributed to any partition. This
+ * is often used in join conditions, where NULL's distribution does not matter because NULL is never
+ * considered equal to any value.
+ * The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
  * [[HashPartitioning]] can satisfy this distribution.
  *
  * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
  * number of partitions, this distribution strictly requires which partition the tuple should be in.
  */
 case class HashClusteredDistribution(
     expressions: Seq[Expression],
-    requiredNumPartitions: Option[Int] = None) extends Distribution {
+    requiredNumPartitions: Option[Int] = None,
+    exceptNull: Boolean = false) extends Distribution {
   require(
     expressions != Nil,
     "The expressions for hash of a HashClusteredDistribution should not be Nil. " +
@@ -112,7 +118,7 @@ case class HashClusteredDistribution(
     assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
       s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
         s"the actual number of partitions is $numPartitions.")
-    HashPartitioning(expressions, numPartitions)
+    HashPartitioning(expressions, numPartitions, exceptNull)
   }
 }
 
@@ -207,12 +213,16 @@ case object SinglePartition extends Partitioning {
 }
 
 /**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * If exceptNull == false: Represents a partitioning where rows are split up across partitions based
+ * on the hash of `expressions`.
+ * If exceptNull == true: Represents a partitioning where rows are split up across partitions based
+ * on the hash of `expressions` except NULL, which is the only key not co-partitioned.
+ * All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
  */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression with Partitioning with Unevaluable {
+case class HashPartitioning(
+    expressions: Seq[Expression], numPartitions: Int, exceptNull: Boolean = false)
+  extends Expression with Partitioning with Unevaluable {
 
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
@@ -222,8 +232,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     super.satisfies0(required) || {
       required match {
         case h: HashClusteredDistribution =>
-          expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
-            case (l, r) => l.semanticEquals(r)
+          expressions.length == h.expressions.length && (h.exceptNull || !exceptNull) &&
+            expressions.zip(h.expressions).forall {
+              case (l, r) => l.semanticEquals(r)
           }
         case ClusteredDistribution(requiredClustering, _) =>
           expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
```
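The new guard `(h.exceptNull || !exceptNull)` in `satisfies0` reads as an implication: a partitioning that only co-partitions non-NULL keys can satisfy a required distribution only if that distribution also tolerates NULLs landing anywhere. A minimal truth-table sketch with hypothetical stand-in types (not the real Spark classes):

```scala
// Hypothetical stand-ins for the two flags, for illustration only.
case class Part(exceptNull: Boolean) // plays the role of HashPartitioning.exceptNull
case class Dist(exceptNull: Boolean) // plays the role of HashClusteredDistribution.exceptNull

// The guard from the patch, "h.exceptNull || !exceptNull":
// partitioning.exceptNull implies distribution.exceptNull.
def nullGuard(p: Part, d: Dist): Boolean = d.exceptNull || !p.exceptNull

assert(nullGuard(Part(exceptNull = false), Dist(exceptNull = false))) // classic case, unchanged
assert(nullGuard(Part(exceptNull = false), Dist(exceptNull = true)))  // strict co-partitioning suffices
assert(!nullGuard(Part(exceptNull = true), Dist(exceptNull = false))) // NULLs could be anywhere: reshuffle
assert(nullGuard(Part(exceptNull = true), Dist(exceptNull = true)))   // the join ignores NULL placement
```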

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 3 additions & 4 deletions
```diff
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -163,13 +162,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
       leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
+        case HashPartitioning(leftExpressions, _, _)
           if leftExpressions.length == leftKeys.length &&
             leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
           reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
 
         case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
+          case HashPartitioning(rightExpressions, _, _)
            if rightExpressions.length == rightKeys.length &&
              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
            reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -190,7 +190,7 @@ object ShuffleExchangeExec {
       serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
-      case HashPartitioning(_, n) =>
+      case HashPartitioning(_, n, _) =>
         new Partitioner {
           override def numPartitions: Int = n
           // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 15 additions & 3 deletions
```diff
@@ -70,15 +70,27 @@ case class SortMergeJoinExec(
     // For left and right outer joins, the output is partitioned by the streamed input's join keys.
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case FullOuter =>
+      // The output of a full outer join is similar to pure HashPartitioning, except for NULL,
+      // which is the only key not co-partitioned.
+      (left.outputPartitioning, right.outputPartitioning) match {
+        case (l: HashPartitioning, r: HashPartitioning) =>
+          PartitioningCollection(Seq(l.copy(exceptNull = true), r.copy(exceptNull = true)))
+        case _ => UnknownPartitioning(left.outputPartitioning.numPartitions)
+      }
     case LeftExistence(_) => left.outputPartitioning
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
   }
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  override def requiredChildDistribution: Seq[Distribution] = joinType match {
+    case Inner | LeftOuter | RightOuter | FullOuter =>
+      HashClusteredDistribution(leftKeys, exceptNull = true) ::
+        HashClusteredDistribution(rightKeys, exceptNull = true) :: Nil
+    case _ => HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  }
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.
```
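Taken together, the two overrides above let the first full outer join's reported partitioning satisfy the next one's requirement. A sketch of that hand-off against the patched catalyst classes (the `a1` attribute and 200 partitions are hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.LongType

// Hypothetical join key attribute (reused so semanticEquals holds).
val a1 = AttributeReference("a1", LongType)()

// What a full outer SortMergeJoinExec now reports when its children are
// hash partitioned on the join keys (one side shown for brevity):
val joinOutput = PartitioningCollection(
  Seq(HashPartitioning(Seq(a1), 200, exceptNull = true)))

// What a following full outer join on a1 requires under the patched
// requiredChildDistribution:
val required = HashClusteredDistribution(Seq(a1), exceptNull = true)

// Satisfied, so EnsureRequirements adds no extra ShuffleExchangeExec.
assert(joinOutput.satisfies(required))
```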

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 38 additions & 3 deletions
```diff
@@ -18,21 +18,21 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{QueryTest, Row, execution}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-class PlannerSuite extends SharedSQLContext {
+class PlannerSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   setupTestData()
@@ -683,6 +683,41 @@ class PlannerSuite extends SharedSQLContext {
       case _ => fail()
     }
   }
+
+  test("EnsureRequirements doesn't add shuffle between 2 successive full outer joins on the " +
+    "same key") {
+    val df1 = spark.range(1, 100, 1, 2).filter(_ % 2 == 0).selectExpr("id as a1")
+    val df2 = spark.range(1, 100, 1, 2).selectExpr("id as b2")
+    val df3 = spark.range(1, 100, 1, 2).selectExpr("id as a3")
+    val fullOuterJoins = df1
+      .join(df2, col("a1") === col("b2"), "full_outer")
+      .join(df3, col("a1") === col("a3"), "full_outer")
+    assert(
+      fullOuterJoins.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }
+        .length === 3)
+    val expected = (1 until 100).filter(_ % 2 == 0).map(i => Row(i, i, i)) ++
+      (1 until 100).filterNot(_ % 2 == 0).map(Row(null, _, null)) ++
+      (1 until 100).filterNot(_ % 2 == 0).map(Row(null, null, _))
+    checkAnswer(fullOuterJoins, expected)
+  }
+
+  test("EnsureRequirements still adds shuffle for non-successive full outer joins on the " +
+    "same key") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      val df1 = spark.range(1, 100).selectExpr("id as a1")
+      val df2 = spark.range(1, 100).selectExpr("id as b2")
+      val df3 = spark.range(1, 100).selectExpr("id as a3")
+      val df4 = spark.range(1, 100).selectExpr("id as a4")
+
+      val fullOuterJoins = df1
+        .join(df2, col("a1") === col("b2"), "full_outer")
+        .join(df3, col("a1") === col("a3"), "left_outer")
+        .join(df4, col("a3") === col("a4"), "full_outer")
+      fullOuterJoins.explain(true)
+      assert(
+        fullOuterJoins.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }
+          .length === 5)
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
```
