
Commit 59e0066

aokolnychyi authored and dongjoon-hyun committed
[SPARK-42779][SQL] Allow V2 writes to indicate advisory shuffle partition size
### What changes were proposed in this pull request?

This PR adds an API for data sources to indicate the advisory partition size for V2 writes.

### Why are the changes needed?

Data sources have an API to request a particular distribution and ordering of data for V2 writes. If AQE is enabled, the default session advisory partition size (64MB) will be used as the target. Unfortunately, this default value is still suboptimal and can lead to small files because the written data can be compressed nicely using columnar file formats. Spark should allow data sources to indicate the advisory shuffle partition size, just like it lets data sources request a particular number of partitions. This feature would allow data sources to estimate the compression ratio and incorporate that in the requested advisory partition size.

### Does this PR introduce _any_ user-facing change?

Yes. However, the changes are backward compatible.

### How was this patch tested?

This PR extends the existing tests for V2 write distribution and ordering.

Closes apache#40421 from aokolnychyi/spark-42779.

Lead-authored-by: aokolnychyi <aokolnychyi@apple.com>
Co-authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent c238086 · commit 59e0066

26 files changed: +290 −119 lines

core/src/main/resources/error/error-classes.json (+22 −5)

```diff
@@ -1063,6 +1063,28 @@
     ],
     "sqlState" : "42903"
   },
+  "INVALID_WRITE_DISTRIBUTION" : {
+    "message" : [
+      "The requested write distribution is invalid."
+    ],
+    "subClass" : {
+      "PARTITION_NUM_AND_SIZE" : {
+        "message" : [
+          "The partition number and advisory partition size can't be specified at the same time."
+        ]
+      },
+      "PARTITION_NUM_WITH_UNSPECIFIED_DISTRIBUTION" : {
+        "message" : [
+          "The number of partitions can't be specified with unspecified distribution."
+        ]
+      },
+      "PARTITION_SIZE_WITH_UNSPECIFIED_DISTRIBUTION" : {
+        "message" : [
+          "The advisory partition size can't be specified with unspecified distribution."
+        ]
+      }
+    }
+  },
   "LOCATION_ALREADY_EXISTS" : {
     "message" : [
       "Cannot name the managed table as <identifier>, as its associated location <location> already exists. Please pick a different table name, or remove the existing location first."
@@ -2931,11 +2953,6 @@
       "Unsupported data type <dataType>."
     ]
   },
-  "_LEGACY_ERROR_TEMP_1178" : {
-    "message" : [
-      "The number of partitions can't be specified with unspecified distribution. Invalid writer requirements detected."
-    ]
-  },
   "_LEGACY_ERROR_TEMP_1181" : {
     "message" : [
       "Stream-stream join without equality predicate is not supported."
```

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java (+22 −1)

```diff
@@ -66,12 +66,33 @@ public interface RequiresDistributionAndOrdering extends Write {
    * <p>
    * Note that Spark doesn't support the number of partitions on {@link UnspecifiedDistribution},
    * the query will fail if the number of partitions are provided but the distribution is
-   * unspecified.
+   * unspecified. Data sources may either request a particular number of partitions or
+   * a preferred partition size via {@link #advisoryPartitionSizeInBytes}, not both.
    *
    * @return the required number of partitions, any value less than 1 mean no requirement.
    */
   default int requiredNumPartitions() { return 0; }
 
+  /**
+   * Returns the advisory (not guaranteed) shuffle partition size in bytes for this write.
+   * <p>
+   * Implementations may override this to indicate the preferable partition size in shuffles
+   * performed to satisfy the requested distribution. Note that Spark doesn't support setting
+   * the advisory partition size for {@link UnspecifiedDistribution}, the query will fail if
+   * the advisory partition size is set but the distribution is unspecified. Data sources may
+   * either request a particular number of partitions via {@link #requiredNumPartitions()} or
+   * a preferred partition size, not both.
+   * <p>
+   * Data sources should be careful with large advisory sizes as it will impact the writing
+   * parallelism and may degrade the overall job performance.
+   * <p>
+   * Note this value only acts like a guidance and Spark does not guarantee the actual and advisory
+   * shuffle partition sizes will match. Ignored if the adaptive execution is disabled.
+   *
+   * @return the advisory partition size, any value less than 1 means no preference.
+   */
+  default long advisoryPartitionSizeInBytes() { return 0; }
+
   /**
    * Returns the ordering required by this write.
    * <p>
```
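
To make the new hook concrete, here is a minimal connector-side sketch (not part of this patch) that ties the API to the compression-ratio reasoning from the PR description. `CompressedFileWrite`, `targetFileSize`, and `estimatedCompressionRatio` are hypothetical names and values, not defined by Spark:

```scala
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, SortOrder}
import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering

// A hypothetical V2 write that clusters data and sizes shuffle partitions
// so each write task produces roughly one target-sized compressed file.
class CompressedFileWrite(clustering: Array[Expression])
  extends RequiresDistributionAndOrdering {

  // Hypothetical connector-side tuning values.
  private val targetFileSize: Long = 128L * 1024 * 1024 // desired on-disk file size (128 MiB)
  private val estimatedCompressionRatio: Double = 3.0   // columnar data shrinks roughly 3x

  // Cluster rows by the table's partition expressions before writing.
  override def requiredDistribution(): Distribution = Distributions.clustered(clustering)

  // No in-partition ordering requirement in this sketch.
  override def requiredOrdering(): Array[SortOrder] = Array.empty

  // Ask AQE for shuffle partitions large enough that 128 MiB on disk at ~3x
  // compression corresponds to ~384 MiB of shuffle data per partition.
  // requiredNumPartitions() keeps its default of 0, since both can't be set.
  override def advisoryPartitionSizeInBytes(): Long =
    (targetFileSize * estimatedCompressionRatio).toLong
}
```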

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (+1 −1)

```diff
@@ -1743,7 +1743,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // table `t` even if there is a Project node between the table scan node and Sort node.
       // We also need to propagate the missing attributes from the descendant node to the current
       // node, and project them way at the end via an extra Project.
-      case r @ RepartitionByExpression(partitionExprs, child, _)
+      case r @ RepartitionByExpression(partitionExprs, child, _, _)
           if !r.resolved || r.missingInput.nonEmpty =>
         val resolvedNoOuter = partitionExprs.map(resolveExpressionByPlanChildren(_, r))
         val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedNoOuter, child)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (+5 −4)

```diff
@@ -1202,15 +1202,16 @@ object CollapseRepartition extends Rule[LogicalPlan] {
     }
     // Case 2: When a RepartitionByExpression has a child of global Sort, Repartition or
     // RepartitionByExpression we can remove the child.
-    case r @ RepartitionByExpression(_, child @ (Sort(_, true, _) | _: RepartitionOperation), _) =>
+    case r @ RepartitionByExpression(
+        _, child @ (Sort(_, true, _) | _: RepartitionOperation), _, _) =>
       r.withNewChildren(child.children)
     // Case 3: When a RebalancePartitions has a child of local or global Sort, Repartition or
     // RepartitionByExpression we can remove the child.
-    case r @ RebalancePartitions(_, child @ (_: Sort | _: RepartitionOperation), _) =>
+    case r @ RebalancePartitions(_, child @ (_: Sort | _: RepartitionOperation), _, _) =>
       r.withNewChildren(child.children)
     // Case 4: When a RebalancePartitions has a child of RebalancePartitions we can remove the
     // child.
-    case r @ RebalancePartitions(_, child: RebalancePartitions, _) =>
+    case r @ RebalancePartitions(_, child: RebalancePartitions, _, _) =>
       r.withNewChildren(child.children)
   }
 }
@@ -1222,7 +1223,7 @@
 object OptimizeRepartition extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsPattern(REPARTITION_OPERATION), ruleId) {
-    case r @ RepartitionByExpression(partitionExpressions, _, numPartitions)
+    case r @ RepartitionByExpression(partitionExpressions, _, numPartitions, _)
       if partitionExpressions.nonEmpty && partitionExpressions.forall(_.foldable) &&
         numPartitions.isEmpty =>
       r.copy(optNumPartitions = Some(1))
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala (+12 −2)

```diff
@@ -1790,6 +1790,8 @@ trait HasPartitionExpressions extends SQLConfHelper {
 
   def optNumPartitions: Option[Int]
 
+  def optAdvisoryPartitionSize: Option[Long]
+
   protected def partitioning: Partitioning = if (partitionExpressions.isEmpty) {
     RoundRobinPartitioning(numPartitions)
   } else {
@@ -1820,7 +1822,11 @@
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    optNumPartitions: Option[Int]) extends RepartitionOperation with HasPartitionExpressions {
+    optNumPartitions: Option[Int],
+    optAdvisoryPartitionSize: Option[Long] = None)
+  extends RepartitionOperation with HasPartitionExpressions {
+
+  require(optNumPartitions.isEmpty || optAdvisoryPartitionSize.isEmpty)
 
   override val partitioning: Partitioning = {
     if (numPartitions == 1) {
@@ -1857,7 +1863,11 @@ object RepartitionByExpression {
 case class RebalancePartitions(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    optNumPartitions: Option[Int] = None) extends UnaryNode with HasPartitionExpressions {
+    optNumPartitions: Option[Int] = None,
+    optAdvisoryPartitionSize: Option[Long] = None) extends UnaryNode with HasPartitionExpressions {
+
+  require(optNumPartitions.isEmpty || optAdvisoryPartitionSize.isEmpty)
+
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
   override val nodePatterns: Seq[TreePattern] = Seq(REBALANCE_PARTITIONS)
```
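
As a quick standalone illustration of the invariant added to both logical nodes (an analog, not Spark code): requesting a partition count and an advisory size together fails at construction time.

```scala
// Standalone analog of the `require` added to RepartitionByExpression/RebalancePartitions.
final case class PartitionSpec(
    optNumPartitions: Option[Int],
    optAdvisoryPartitionSize: Option[Long]) {
  require(optNumPartitions.isEmpty || optAdvisoryPartitionSize.isEmpty,
    "the partition number and advisory partition size can't be specified at the same time")
}

object PartitionSpecDemo extends App {
  PartitionSpec(Some(10), None)        // ok: only a partition count
  PartitionSpec(None, Some(64L << 20)) // ok: only an advisory size (64 MiB)
  // PartitionSpec(Some(10), Some(64L << 20)) // would throw IllegalArgumentException
}
```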

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala (+13 −1)

```diff
@@ -1803,7 +1803,19 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
 
   def numberOfPartitionsNotAllowedWithUnspecifiedDistributionError(): Throwable = {
     new AnalysisException(
-      errorClass = "_LEGACY_ERROR_TEMP_1178",
+      errorClass = "INVALID_WRITE_DISTRIBUTION.PARTITION_NUM_WITH_UNSPECIFIED_DISTRIBUTION",
+      messageParameters = Map.empty)
+  }
+
+  def partitionSizeNotAllowedWithUnspecifiedDistributionError(): Throwable = {
+    new AnalysisException(
+      errorClass = "INVALID_WRITE_DISTRIBUTION.PARTITION_SIZE_WITH_UNSPECIFIED_DISTRIBUTION",
+      messageParameters = Map.empty)
+  }
+
+  def numberAndSizeOfPartitionsNotAllowedTogether(): Throwable = {
+    new AnalysisException(
+      errorClass = "INVALID_WRITE_DISTRIBUTION.PARTITION_NUM_AND_SIZE",
       messageParameters = Map.empty)
   }
 
```

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala (+5 −0)

```diff
@@ -55,6 +55,7 @@ abstract class InMemoryBaseTable(
     val distribution: Distribution = Distributions.unspecified(),
     val ordering: Array[SortOrder] = Array.empty,
     val numPartitions: Option[Int] = None,
+    val advisoryPartitionSize: Option[Long] = None,
     val isDistributionStrictlyRequired: Boolean = true,
     val numRowsPerSplit: Int = Int.MaxValue)
   extends Table with SupportsRead with SupportsWrite with SupportsMetadataColumns {
@@ -450,6 +451,10 @@
       numPartitions.getOrElse(0)
     }
 
+    override def advisoryPartitionSizeInBytes(): Long = {
+      advisoryPartitionSize.getOrElse(0)
+    }
+
     override def toBatch: BatchWrite = writer
 
     override def toStreaming: StreamingWrite = streamingWriter match {
```

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala (+3 −1)

```diff
@@ -39,10 +39,12 @@ class InMemoryTable(
     distribution: Distribution = Distributions.unspecified(),
     ordering: Array[SortOrder] = Array.empty,
     numPartitions: Option[Int] = None,
+    advisoryPartitionSize: Option[Long] = None,
     isDistributionStrictlyRequired: Boolean = true,
     override val numRowsPerSplit: Int = Int.MaxValue)
   extends InMemoryBaseTable(name, schema, partitioning, properties, distribution,
-    ordering, numPartitions, isDistributionStrictlyRequired, numRowsPerSplit) with SupportsDelete {
+    ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired,
+    numRowsPerSplit) with SupportsDelete {
 
   override def canDeleteWhere(filters: Array[Filter]): Boolean = {
     InMemoryTable.supportsFilters(filters)
```

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala (+4 −2)

```diff
@@ -91,7 +91,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
       partitions: Array[Transform],
       properties: util.Map[String, String]): Table = {
     createTable(ident, schema, partitions, properties, Distributions.unspecified(),
-      Array.empty, None)
+      Array.empty, None, None)
   }
 
   override def createTable(
@@ -111,6 +111,7 @@
       distribution: Distribution,
       ordering: Array[SortOrder],
       requiredNumPartitions: Option[Int],
+      advisoryPartitionSize: Option[Long],
       distributionStrictlyRequired: Boolean = true,
       numRowsPerSplit: Int = Int.MaxValue): Table = {
     if (tables.containsKey(ident)) {
@@ -121,7 +122,8 @@
 
     val tableName = s"$name.${ident.quoted}"
     val table = new InMemoryTable(tableName, schema, partitions, properties, distribution,
-      ordering, requiredNumPartitions, distributionStrictlyRequired, numRowsPerSplit)
+      ordering, requiredNumPartitions, advisoryPartitionSize, distributionStrictlyRequired,
+      numRowsPerSplit)
     tables.put(ident, table)
     namespaces.putIfAbsent(ident.namespace.toList, Map())
     table
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala (+6 −2)

```diff
@@ -893,14 +893,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         } else {
           REPARTITION_BY_NUM
         }
-        exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
+        exchange.ShuffleExchangeExec(
+          r.partitioning, planLater(r.child),
+          shuffleOrigin, r.optAdvisoryPartitionSize) :: Nil
       case r: logical.RebalancePartitions =>
         val shuffleOrigin = if (r.partitionExpressions.isEmpty) {
           REBALANCE_PARTITIONS_BY_NONE
         } else {
           REBALANCE_PARTITIONS_BY_COL
         }
-        exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), shuffleOrigin) :: Nil
+        exchange.ShuffleExchangeExec(
+          r.partitioning, planLater(r.child),
+          shuffleOrigin, r.optAdvisoryPartitionSize) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
```

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala (+1 −1)

```diff
@@ -30,7 +30,7 @@ object AQEUtils {
     // Project/Filter/LocalSort/CollectMetrics.
     // Note: we only care about `HashPartitioning` as `EnsureRequirements` can only optimize out
     // user-specified repartition with `HashPartitioning`.
-    case ShuffleExchangeExec(h: HashPartitioning, _, shuffleOrigin)
+    case ShuffleExchangeExec(h: HashPartitioning, _, shuffleOrigin, _)
         if shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM =>
       val numPartitions = if (shuffleOrigin == REPARTITION_BY_NUM) {
         Some(h.numPartitions)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala (+20 −10)

```diff
@@ -64,16 +64,6 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
         1
       }
     }
-    val advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
-    val minPartitionSize = if (Utils.isTesting) {
-      // In the tests, we usually set the target size to a very small value that is even smaller
-      // than the default value of the min partition size. Here we also adjust the min partition
-      // size to be not larger than 20% of the target size, so that the tests don't need to set
-      // both configs all the time to check the coalescing behavior.
-      conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE).min(advisoryTargetSize / 5)
-    } else {
-      conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE)
-    }
 
     // Sub-plans under the Union operator can be coalesced independently, so we can divide them
     // into independent "coalesce groups", and all shuffle stages within each group have to be
@@ -100,6 +90,17 @@
     val specsMap = mutable.HashMap.empty[Int, Seq[ShufflePartitionSpec]]
     // Coalesce partitions for each coalesce group independently.
     coalesceGroups.zip(minNumPartitionsByGroup).foreach { case (shuffleStages, minNumPartitions) =>
+      val advisoryTargetSize = advisoryPartitionSize(shuffleStages)
+      val minPartitionSize = if (Utils.isTesting) {
+        // In the tests, we usually set the target size to a very small value that is even smaller
+        // than the default value of the min partition size. Here we also adjust the min partition
+        // size to be not larger than 20% of the target size, so that the tests don't need to set
+        // both configs all the time to check the coalescing behavior.
+        conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE).min(advisoryTargetSize / 5)
+      } else {
+        conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE)
+      }
+
       val newPartitionSpecs = ShufflePartitionsUtil.coalescePartitions(
         shuffleStages.map(_.shuffleStage.mapStats),
         shuffleStages.map(_.partitionSpecs),
@@ -121,6 +122,15 @@
     }
   }
 
+  private def advisoryPartitionSize(shuffleStages: Seq[ShuffleStageInfo]): Long = {
+    val advisorySizes = shuffleStages.flatMap(_.shuffleStage.advisoryPartitionSize).toSet
+    if (advisorySizes.size == 1) {
+      advisorySizes.head
+    } else {
+      conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
+    }
+  }
+
   /**
    * Gather all coalesce-able groups such that the shuffle stages in each child of a Union operator
    * are in their independent groups if:
```
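
The per-group resolution rule added above can be summarized with a small standalone sketch (simplified names, same semantics as `advisoryPartitionSize` in the patch): a coalesce group adopts a stage-provided advisory size only when all stages that specify one agree, and otherwise falls back to the session default.

```scala
object AdvisorySizeResolutionDemo extends App {
  // Standalone sketch of the group-level advisory size resolution (simplified from the patch).
  def resolveAdvisorySize(stageSizes: Seq[Option[Long]], sessionDefault: Long): Long = {
    val distinct = stageSizes.flatten.toSet // ignore unset stages, keep distinct requested sizes
    if (distinct.size == 1) distinct.head else sessionDefault
  }

  val default = 64L << 20 // 64 MiB session default

  assert(resolveAdvisorySize(Seq(Some(384L << 20), Some(384L << 20)), default) == (384L << 20)) // stages agree
  assert(resolveAdvisorySize(Seq(Some(384L << 20), None), default) == (384L << 20))             // only one stage asks
  assert(resolveAdvisorySize(Seq(Some(384L << 20), Some(128L << 20)), default) == default)      // conflicting requests
  assert(resolveAdvisorySize(Seq(None, None), default) == default)                              // nothing requested
}
```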

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala (+2 −1)

```diff
@@ -69,7 +69,8 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule {
   }
 
   private def tryOptimizeSkewedPartitions(shuffle: ShuffleQueryStageExec): SparkPlan = {
-    val advisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
+    val defaultAdvisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
+    val advisorySize = shuffle.advisoryPartitionSize.getOrElse(defaultAdvisorySize)
     val mapStats = shuffle.mapStats
     if (mapStats.isEmpty ||
       mapStats.get.bytesByPartitionId.forall(_ <= advisorySize)) {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala (+2 −0)

```diff
@@ -180,6 +180,8 @@ case class ShuffleQueryStageExec(
       throw new IllegalStateException(s"wrong plan for shuffle stage:\n ${plan.treeString}")
   }
 
+  @transient val advisoryPartitionSize: Option[Long] = shuffle.advisoryPartitionSize
+
   @transient private lazy val shuffleFuture = shuffle.submitShuffleJob
 
   override protected def doMaterialize(): Future[Any] = shuffleFuture
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala (+11 −2)

```diff
@@ -36,6 +36,7 @@ object DistributionAndOrderingUtils {
       funCatalogOpt: Option[FunctionCatalog]): LogicalPlan = write match {
     case write: RequiresDistributionAndOrdering =>
       val numPartitions = write.requiredNumPartitions()
+      val partitionSize = write.advisoryPartitionSizeInBytes()
 
       val distribution = write.requiredDistribution match {
         case d: OrderedDistribution =>
@@ -49,17 +50,25 @@
 
       val queryWithDistribution = if (distribution.nonEmpty) {
         val optNumPartitions = if (numPartitions > 0) Some(numPartitions) else None
+        val optPartitionSize = if (partitionSize > 0) Some(partitionSize) else None
+
+        if (optNumPartitions.isDefined && optPartitionSize.isDefined) {
+          throw QueryCompilationErrors.numberAndSizeOfPartitionsNotAllowedTogether()
+        }
+
         // the conversion to catalyst expressions above produces SortOrder expressions
         // for OrderedDistribution and generic expressions for ClusteredDistribution
         // this allows RebalancePartitions/RepartitionByExpression to pick either
         // range or hash partitioning
         if (write.distributionStrictlyRequired()) {
-          RepartitionByExpression(distribution, query, optNumPartitions)
+          RepartitionByExpression(distribution, query, optNumPartitions, optPartitionSize)
         } else {
-          RebalancePartitions(distribution, query, optNumPartitions)
+          RebalancePartitions(distribution, query, optNumPartitions, optPartitionSize)
        }
      } else if (numPartitions > 0) {
        throw QueryCompilationErrors.numberOfPartitionsNotAllowedWithUnspecifiedDistributionError()
+      } else if (partitionSize > 0) {
+        throw QueryCompilationErrors.partitionSizeNotAllowedWithUnspecifiedDistributionError()
      } else {
        query
      }
```
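
The validation wired up above reduces to three rejection cases, matching the three new error subclasses. A standalone sketch (same semantics, simplified names; not the actual Spark code path):

```scala
// Standalone sketch of the writer-requirement checks added above.
def validateWriteRequirements(
    hasDistribution: Boolean, // false for UnspecifiedDistribution
    numPartitions: Int,       // < 1 means "no requirement"
    partitionSize: Long       // < 1 means "no preference"
): Either[String, Unit] = {
  if (hasDistribution && numPartitions > 0 && partitionSize > 0) {
    Left("INVALID_WRITE_DISTRIBUTION.PARTITION_NUM_AND_SIZE")
  } else if (!hasDistribution && numPartitions > 0) {
    Left("INVALID_WRITE_DISTRIBUTION.PARTITION_NUM_WITH_UNSPECIFIED_DISTRIBUTION")
  } else if (!hasDistribution && partitionSize > 0) {
    Left("INVALID_WRITE_DISTRIBUTION.PARTITION_SIZE_WITH_UNSPECIFIED_DISTRIBUTION")
  } else {
    Right(()) // valid: at most one of the two knobs, and only with a concrete distribution
  }
}
```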

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala (+3 −2)

```diff
@@ -188,7 +188,8 @@ case class EnsureRequirements(
     }
 
     child match {
-      case ShuffleExchangeExec(_, c, so) => ShuffleExchangeExec(newPartitioning, c, so)
+      case ShuffleExchangeExec(_, c, so, ps) =>
+        ShuffleExchangeExec(newPartitioning, c, so, ps)
       case _ => ShuffleExchangeExec(newPartitioning, child)
     }
   }
@@ -578,7 +579,7 @@
 
   def apply(plan: SparkPlan): SparkPlan = {
     val newPlan = plan.transformUp {
-      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
+      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _)
         if optimizeOutRepartition &&
           (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
         def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
```
