From e3f3e0bac99327c5cb67d8bac8fe610708ac339f Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 28 Aug 2023 14:32:24 -0700 Subject: [PATCH] initial commit --- .../plans/physical/partitioning.scala | 7 +- .../datasources/v2/BatchScanExec.scala | 117 ++++++++---------- .../v2/DataSourceV2ScanExecBase.scala | 65 +++++----- .../exchange/EnsureRequirements.scala | 9 +- .../DistributionAndOrderingSuiteBase.scala | 6 +- .../KeyGroupedPartitioningSuite.scala | 2 +- .../exchange/EnsureRequirementsSuite.scala | 2 +- 7 files changed, 102 insertions(+), 106 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ce557422a087a..3d897ec4af7f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -327,11 +327,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * @param numPartitions the number of partitions * @param partitionValues the values for the cluster keys of the distribution, must be * in ascending order. + * @param originalPartitionValues the original input partition values before any grouping has been + * applied, must be in ascending order. */ case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { + partitionValues: Seq[InternalRow] = Seq.empty, + originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { override def satisfies0(required: Distribution): Boolean = { super.satisfies0(required) || { @@ -368,7 +371,7 @@ object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues) + KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index cc674961f8eb5..c274006972d50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.internal.SQLConf /** * Physical plan node for scanning a batch of data from a data source v2. @@ -101,7 +100,7 @@ case class BatchScanExec( "partition values that are not present in the original partitioning.") } - groupPartitions(newPartitions).get.map(_._2) + groupPartitions(newPartitions).get.groupedParts.map(_.parts) case _ => // no validation is needed as the data source did not report any specific partitioning @@ -137,81 +136,63 @@ case class BatchScanExec( outputPartitioning match { case p: KeyGroupedPartitioning => - if (conf.v2BucketingPushPartValuesEnabled && - conf.v2BucketingPartiallyClusteredDistributionEnabled) { - assert(filteredPartitions.forall(_.size == 1), - "Expect partitions to be not grouped when " + - s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + - "is enabled") - - val groupedPartitions = groupPartitions(finalPartitions.map(_.head), - groupSplits = true).get - - // This means the input partitions are not grouped by partition values. We'll need to - // check `groupByPartitionValues` and decide whether to group and replicate splits - // within a partition. - if (spjParams.commonPartitionValues.isDefined && - spjParams.applyPartialClustering) { - // A mapping from the common partition values to how many splits the partition - // should contain. - val commonPartValuesMap = spjParams.commonPartitionValues + val groupedPartitions = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + + // This means the input partitions are not grouped by partition values. We'll need to + // check `groupByPartitionValues` and decide whether to group and replicate splits + // within a partition. + if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { + // A mapping from the common partition values to how many splits the partition + // should contain. Note this no longer maintain the partition key ordering. + val commonPartValuesMap = spjParams.commonPartitionValues .get .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { - case (partValue, splits) => - // `commonPartValuesMap` should contain the part value since it's the super set. - val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) - assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + - "common partition values from Spark plan") - - val newSplits = if (spjParams.replicatePartitions) { - // We need to also replicate partitions according to the other side of join - Seq.fill(numSplits.get)(splits) - } else { - // Not grouping by partition values: this could be the side with partially - // clustered distribution. Because of dynamic filtering, we'll need to check if - // the final number of splits of a partition is smaller than the original - // number, and fill with empty splits if so. This is necessary so that both - // sides of a join will have the same number of partitions & splits. - splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) - } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + // `commonPartValuesMap` should contain the part value since it's the super set. + val numSplits = commonPartValuesMap + .get(InternalRowComparableWrapper(partValue, p.expressions)) + assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + + "common partition values from Spark plan") + + val newSplits = if (spjParams.replicatePartitions) { + // We need to also replicate partitions according to the other side of join + Seq.fill(numSplits.get)(splits) + } else { + // Not grouping by partition values: this could be the side with partially + // clustered distribution. Because of dynamic filtering, we'll need to check if + // the final number of splits of a partition is smaller than the original + // number, and fill with empty splits if so. This is necessary so that both + // sides of a join will have the same number of partitions & splits. + splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } + (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + } - // Now fill missing partition keys with empty partitions - val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { - case (partValue, numSplits) => - // Use empty partition for those partition values that are not present. - partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), - Seq.fill(numSplits)(Seq.empty)) - } - } else { - // either `commonPartitionValues` is not defined, or it is defined but - // `applyPartialClustering` is false. - val partitionMapping = groupedPartitions.map { case (row, parts) => - InternalRowComparableWrapper(row, p.expressions) -> parts - }.toMap - - // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there - // could exist duplicated partition values, as partition grouping is not done - // at the beginning and postponed to this method. It is important to use unique - // partition values here so that grouped partitions won't get duplicated. - finalPartitions = p.uniquePartitionValues.map { partValue => - // Use empty partition for those partition values that are not present + // Now fill missing partition keys with empty partitions + val partitionMapping = nestGroupedPartitions.toMap + finalPartitions = spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) - } + InternalRowComparableWrapper(partValue, p.expressions), + Seq.fill(numSplits)(Seq.empty)) } } else { - val partitionMapping = finalPartitions.map { parts => - val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey() - InternalRowComparableWrapper(row, p.expressions) -> parts + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. + val partitionMapping = groupedPartitions.map { case (partValue, splits) => + InternalRowComparableWrapper(partValue, p.expressions) -> splits }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. + finalPartitions = p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index f688d3514d9aa..94667fbd00c18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { redact(result) } - def partitions: Seq[Seq[InputPartition]] = - groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_))) + def partitions: Seq[Seq[InputPartition]] = { + groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_))) + } /** * Shorthand for calling redact() without specifying redacting rules @@ -94,8 +95,10 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { keyGroupedPartitioning match { case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) => groupedPartitions - .map { partitionValues => - KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1)) + .map { keyGroupedPartsInfo => + val keyGroupedParts = keyGroupedPartsInfo.groupedParts + KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value), + keyGroupedPartsInfo.originalParts.map(_.partitionKey())) } .getOrElse(super.outputPartitioning) case _ => @@ -103,7 +106,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } - @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = { + @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = { // Early check if we actually need to materialize the input partitions. keyGroupedPartitioning match { case Some(_) => groupPartitions(inputPartitions) @@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { * - all input partitions implement [[HasPartitionKey]] * - `keyGroupedPartitioning` is set * - * The result, if defined, is a list of tuples where the first element is a partition value, - * and the second element is a list of input partitions that share the same partition value. + * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of + * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits, + * sorted according to the partition keys in ascending order. * * A non-empty result means each partition is clustered on a single key and therefore eligible * for further optimizations to eliminate shuffling in some operations such as join and aggregate. */ - def groupPartitions( - inputPartitions: Seq[InputPartition], - groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled || - !conf.v2BucketingPartiallyClusteredDistributionEnabled): - Option[Seq[(InternalRow, Seq[InputPartition])]] = { - + def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = { if (!SQLConf.get.v2BucketingEnabled) return None + keyGroupedPartitioning.flatMap { expressions => val results = inputPartitions.takeWhile { case _: HasPartitionKey => true case _ => false - }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p)) + }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey])) if (results.length != inputPartitions.length || inputPartitions.isEmpty) { // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. @@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { // also sort the input partitions according to their partition key order. This ensures // a canonical order from both sides of a bucketed join, for example. val partitionDataTypes = expressions.map(_.dataType) - val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { + val partitionOrdering: Ordering[(InternalRow, InputPartition)] = { RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1) } - - val partitions = if (groupSplits) { - // Group the splits by their partition value - results + val sortedKeyToPartitions = results.sorted(partitionOrdering) + val groupedPartitions = sortedKeyToPartitions .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2)) .groupBy(_._1) .toSeq - .map { - case (key, s) => (key.row, s.map(_._2)) - } - } else { - // No splits grouping, each split will become a separate Spark partition - results.map(t => (t._1, Seq(t._2))) - } + .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) } - Some(partitions.sorted(partitionOrdering)) + Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2))) } } } override def outputOrdering: Seq[SortOrder] = { // when multiple partitions are grouped together, ordering inside partitions is not preserved - val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1)) + val partitioningPreservesOrdering = groupedPartitions + .forall(_.groupedParts.forall(_.parts.length <= 1)) ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering) } @@ -217,3 +210,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } } + +/** + * A key-grouped Spark partition, which could consist of multiple input splits + * + * @param value the partition value shared by all the input splits + * @param parts the input splits that are grouped into a single Spark partition + */ +private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition]) + +/** + * Information about key-grouped partitions, which contains a list of grouped partitions as well + * as the original input partitions before the grouping. + */ +private[v2] case class KeyGroupedPartitionInfo( + groupedParts: Seq[KeyGroupedPartition], + originalParts: Seq[HasPartitionKey]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 42c880e7c6262..f8e6fd1d0167f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -288,12 +288,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -483,7 +483,10 @@ case class EnsureRequirements( s"'$joinType'. Skipping partially clustered distribution.") replicateRightSide = false } else { - val partValues = if (replicateLeftSide) rightPartValues else leftPartValues + // In partially clustered distribution, we should use un-grouped partition values + val spec = if (replicateLeftSide) rightSpec else leftSpec + val partValues = spec.partitioning.originalPartitionValues + val numExpectedPartitions = partValues .map(InternalRowComparableWrapper(_, partitionExprs)) .groupBy(identity) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index f4317e632761c..1a0efa7c4aafb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => - KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, - partitionValues) + case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => + KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, + originalPartValues) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 5b5e402117384..b22aba61aabd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -131,7 +131,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // Has exactly one partition. val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, distribution, - physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues)) + physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues)) } test("non-clustered distribution: no V2 catalog") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 3c9b92e5f66b6..3b0bb088a1076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv), + ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil)