initial commit

sunchao committed Aug 31, 2023
1 parent 723a0aa commit e3f3e0b
Showing 7 changed files with 102 additions and 106 deletions.
@@ -327,11 +327,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
  * @param numPartitions the number of partitions
  * @param partitionValues the values for the cluster keys of the distribution, must be
  *                        in ascending order.
+ * @param originalPartitionValues the original input partition values before any grouping has been
+ *                                applied, must be in ascending order.
  */
 case class KeyGroupedPartitioning(
     expressions: Seq[Expression],
     numPartitions: Int,
-    partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
+    partitionValues: Seq[InternalRow] = Seq.empty,
+    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
 
   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
@@ -368,7 +371,7 @@ object KeyGroupedPartitioning {
   def apply(
       expressions: Seq[Expression],
       partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues)
+    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
   }
 
   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
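To illustrate the new field (a minimal sketch, not part of this commit; the keys below are hypothetical): under partially clustered distribution, several input splits can share one partition key, so `originalPartitionValues` keeps one entry per input split while `partitionValues` keeps one entry per grouped Spark partition. The two-argument `apply` above receives already-grouped values, so it sets both fields to the same list.

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical keys: two splits with key 42 were grouped into one Spark partition.
val grouped  = Seq(InternalRow(42), InternalRow(43))                   // one per Spark partition
val original = Seq(InternalRow(42), InternalRow(42), InternalRow(43))  // one per input split

// Constructing the partitioning with both the grouped and the un-grouped values:
// KeyGroupedPartitioning(expressions, grouped.size, grouped, original)
```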
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -101,7 +100,7 @@ case class BatchScanExec(
             "partition values that are not present in the original partitioning.")
         }
 
-        groupPartitions(newPartitions).get.map(_._2)
+        groupPartitions(newPartitions).get.groupedParts.map(_.parts)
 
       case _ =>
         // no validation is needed as the data source did not report any specific partitioning
@@ -137,81 +136,63 @@ case class BatchScanExec(
 
     outputPartitioning match {
       case p: KeyGroupedPartitioning =>
-        if (conf.v2BucketingPushPartValuesEnabled &&
-            conf.v2BucketingPartiallyClusteredDistributionEnabled) {
-          assert(filteredPartitions.forall(_.size == 1),
-            "Expect partitions to be not grouped when " +
-              s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
-              "is enabled")
-
-          val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
-            groupSplits = true).get
-
-          // This means the input partitions are not grouped by partition values. We'll need to
-          // check `groupByPartitionValues` and decide whether to group and replicate splits
-          // within a partition.
-          if (spjParams.commonPartitionValues.isDefined &&
-              spjParams.applyPartialClustering) {
-            // A mapping from the common partition values to how many splits the partition
-            // should contain.
-            val commonPartValuesMap = spjParams.commonPartitionValues
-              .get
-              .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
-              .toMap
-            val nestGroupedPartitions = groupedPartitions.map {
-              case (partValue, splits) =>
-                // `commonPartValuesMap` should contain the part value since it's the super set.
-                val numSplits = commonPartValuesMap
-                  .get(InternalRowComparableWrapper(partValue, p.expressions))
-                assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
-                  "common partition values from Spark plan")
-
-                val newSplits = if (spjParams.replicatePartitions) {
-                  // We need to also replicate partitions according to the other side of join
-                  Seq.fill(numSplits.get)(splits)
-                } else {
-                  // Not grouping by partition values: this could be the side with partially
-                  // clustered distribution. Because of dynamic filtering, we'll need to check if
-                  // the final number of splits of a partition is smaller than the original
-                  // number, and fill with empty splits if so. This is necessary so that both
-                  // sides of a join will have the same number of partitions & splits.
-                  splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
-                }
-                (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
-            }
-
-            // Now fill missing partition keys with empty partitions
-            val partitionMapping = nestGroupedPartitions.toMap
-            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
-              case (partValue, numSplits) =>
-                // Use empty partition for those partition values that are not present.
-                partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions),
-                  Seq.fill(numSplits)(Seq.empty))
-            }
-          } else {
-            // either `commonPartitionValues` is not defined, or it is defined but
-            // `applyPartialClustering` is false.
-            val partitionMapping = groupedPartitions.map { case (row, parts) =>
-              InternalRowComparableWrapper(row, p.expressions) -> parts
-            }.toMap
-
-            // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
-            // could exist duplicated partition values, as partition grouping is not done
-            // at the beginning and postponed to this method. It is important to use unique
-            // partition values here so that grouped partitions won't get duplicated.
-            finalPartitions = p.uniquePartitionValues.map { partValue =>
-              // Use empty partition for those partition values that are not present
-              partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
-            }
-          }
-        } else {
-          val partitionMapping = finalPartitions.map { parts =>
-            val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
-            InternalRowComparableWrapper(row, p.expressions) -> parts
-          }.toMap
-          finalPartitions = p.partitionValues.map { partValue =>
-            // Use empty partition for those partition values that are not present
-            partitionMapping.getOrElse(
-              InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+        val groupedPartitions = filteredPartitions.map(splits => {
+          assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+          (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+        })
+
+        // This means the input partitions are not grouped by partition values. We'll need to
+        // check `groupByPartitionValues` and decide whether to group and replicate splits
+        // within a partition.
+        if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) {
+          // A mapping from the common partition values to how many splits the partition
+          // should contain. Note this no longer maintain the partition key ordering.
+          val commonPartValuesMap = spjParams.commonPartitionValues
+            .get
+            .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+            .toMap
+          val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
+            // `commonPartValuesMap` should contain the part value since it's the super set.
+            val numSplits = commonPartValuesMap
+              .get(InternalRowComparableWrapper(partValue, p.expressions))
+            assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+              "common partition values from Spark plan")
+
+            val newSplits = if (spjParams.replicatePartitions) {
+              // We need to also replicate partitions according to the other side of join
+              Seq.fill(numSplits.get)(splits)
+            } else {
+              // Not grouping by partition values: this could be the side with partially
+              // clustered distribution. Because of dynamic filtering, we'll need to check if
+              // the final number of splits of a partition is smaller than the original
+              // number, and fill with empty splits if so. This is necessary so that both
+              // sides of a join will have the same number of partitions & splits.
+              splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+            }
+            (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+          }
+
+          // Now fill missing partition keys with empty partitions
+          val partitionMapping = nestGroupedPartitions.toMap
+          finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+            case (partValue, numSplits) =>
+              // Use empty partition for those partition values that are not present.
+              partitionMapping.getOrElse(
+                InternalRowComparableWrapper(partValue, p.expressions),
+                Seq.fill(numSplits)(Seq.empty))
+          }
+        } else {
+          // either `commonPartitionValues` is not defined, or it is defined but
+          // `applyPartialClustering` is false.
+          val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
+            InternalRowComparableWrapper(partValue, p.expressions) -> splits
+          }.toMap
+
+          // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+          // could exist duplicated partition values, as partition grouping is not done
+          // at the beginning and postponed to this method. It is important to use unique
+          // partition values here so that grouped partitions won't get duplicated.
+          finalPartitions = p.uniquePartitionValues.map { partValue =>
+            // Use empty partition for those partition values that are not present
+            partitionMapping.getOrElse(
+              InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
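The padding and replication logic above can be sketched with plain collections (a simplified sketch with hypothetical split names; the real code operates on `Seq[InputPartition]` keyed by `InternalRowComparableWrapper`):

```scala
// One partition value on this side owns two splits, but the common partition
// values negotiated by the planner say this key should yield four Spark partitions.
val splits = Seq("split-a", "split-b")
val numSplits = 4

// Partially clustered side: one Spark partition per split, padded with empty
// splits so both join sides line up partition-for-partition.
val padded = splits.map(Seq(_)).padTo(numSplits, Seq.empty)
// => Seq(Seq("split-a"), Seq("split-b"), Seq(), Seq())

// Replicated side: the whole group is repeated numSplits times.
val replicated = Seq.fill(numSplits)(splits)
// => four copies of Seq("split-a", "split-b")
```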
@@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     redact(result)
   }
 
-  def partitions: Seq[Seq[InputPartition]] =
-    groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_)))
+  def partitions: Seq[Seq[InputPartition]] = {
+    groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_)))
+  }
 
   /**
    * Shorthand for calling redact() without specifying redacting rules
@@ -94,16 +95,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     keyGroupedPartitioning match {
       case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
         groupedPartitions
-          .map { partitionValues =>
-            KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1))
+          .map { keyGroupedPartsInfo =>
+            val keyGroupedParts = keyGroupedPartsInfo.groupedParts
+            KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value),
+              keyGroupedPartsInfo.originalParts.map(_.partitionKey()))
           }
           .getOrElse(super.outputPartitioning)
       case _ =>
         super.outputPartitioning
     }
   }
 
-  @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = {
+  @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = {
     // Early check if we actually need to materialize the input partitions.
     keyGroupedPartitioning match {
       case Some(_) => groupPartitions(inputPartitions)
@@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
    * - all input partitions implement [[HasPartitionKey]]
    * - `keyGroupedPartitioning` is set
    *
-   * The result, if defined, is a list of tuples where the first element is a partition value,
-   * and the second element is a list of input partitions that share the same partition value.
+   * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of
+   * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits,
+   * sorted according to the partition keys in ascending order.
    *
    * A non-empty result means each partition is clustered on a single key and therefore eligible
    * for further optimizations to eliminate shuffling in some operations such as join and aggregate.
    */
-  def groupPartitions(
-      inputPartitions: Seq[InputPartition],
-      groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled ||
-        !conf.v2BucketingPartiallyClusteredDistributionEnabled):
-      Option[Seq[(InternalRow, Seq[InputPartition])]] = {
-
+  def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = {
     if (!SQLConf.get.v2BucketingEnabled) return None
 
     keyGroupedPartitioning.flatMap { expressions =>
       val results = inputPartitions.takeWhile {
         case _: HasPartitionKey => true
         case _ => false
-      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p))
+      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey]))
 
       if (results.length != inputPartitions.length || inputPartitions.isEmpty) {
         // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here.
@@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
         // also sort the input partitions according to their partition key order. This ensures
         // a canonical order from both sides of a bucketed join, for example.
         val partitionDataTypes = expressions.map(_.dataType)
-        val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
+        val partitionOrdering: Ordering[(InternalRow, InputPartition)] = {
           RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
         }
 
-        val partitions = if (groupSplits) {
-          // Group the splits by their partition value
-          results
+        val sortedKeyToPartitions = results.sorted(partitionOrdering)
+        val groupedPartitions = sortedKeyToPartitions
           .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
           .groupBy(_._1)
           .toSeq
-          .map {
-            case (key, s) => (key.row, s.map(_._2))
-          }
-        } else {
-          // No splits grouping, each split will become a separate Spark partition
-          results.map(t => (t._1, Seq(t._2)))
-        }
+          .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) }
 
-        Some(partitions.sorted(partitionOrdering))
+        Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2)))
       }
     }
   }
 
   override def outputOrdering: Seq[SortOrder] = {
     // when multiple partitions are grouped together, ordering inside partitions is not preserved
-    val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1))
+    val partitioningPreservesOrdering = groupedPartitions
+      .forall(_.groupedParts.forall(_.parts.length <= 1))
     ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering)
   }
 
@@ -217,3 +210,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
       }
     }
   }
+
+/**
+ * A key-grouped Spark partition, which could consist of multiple input splits
+ *
+ * @param value the partition value shared by all the input splits
+ * @param parts the input splits that are grouped into a single Spark partition
+ */
+private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition])
+
+/**
+ * Information about key-grouped partitions, which contains a list of grouped partitions as well
+ * as the original input partitions before the grouping.
+ */
+private[v2] case class KeyGroupedPartitionInfo(
+    groupedParts: Seq[KeyGroupedPartition],
+    originalParts: Seq[HasPartitionKey])
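As a rough sketch of what `groupPartitions` now returns (hypothetical integer keys standing in for `InternalRow`, strings standing in for input splits; the real code sorts with `RowOrdering` and wraps keys in `InternalRowComparableWrapper`):

```scala
// Splits keyed by partition value, sorted ascending by key first.
val sortedKeyToParts = Seq((2, "s3"), (1, "s1"), (1, "s2")).sortBy(_._1)

// groupedParts: one entry per distinct key, holding all of that key's splits.
val groupedParts = sortedKeyToParts.groupBy(_._1).toSeq.sortBy(_._1)
  .map { case (key, kps) => (key, kps.map(_._2)) }
// => Seq((1, Seq("s1", "s2")), (2, Seq("s3")))

// originalParts: the un-grouped splits in sorted-key order, one per split.
val originalParts = sortedKeyToParts.map(_._2)
// => Seq("s1", "s2", "s3")
```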
@@ -288,12 +288,12 @@ case class EnsureRequirements(
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, leftPartitioning, None))
-      case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
+      case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, None, rightPartitioning))
-      case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
+      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
           .orElse(reorderJoinKeysRecursively(
@@ -483,7 +483,10 @@ case class EnsureRequirements(
             s"'$joinType'. Skipping partially clustered distribution.")
           replicateRightSide = false
         } else {
-          val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
+          // In partially clustered distribution, we should use un-grouped partition values
+          val spec = if (replicateLeftSide) rightSpec else leftSpec
+          val partValues = spec.partitioning.originalPartitionValues
+
           val numExpectedPartitions = partValues
             .map(InternalRowComparableWrapper(_, partitionExprs))
             .groupBy(identity)
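The expected partition count derived from the un-grouped values can be sketched as follows (hypothetical keys; the continuation of the real code is truncated above, and it wraps each `InternalRow` in `InternalRowComparableWrapper` before grouping):

```scala
// Un-grouped partition values: key 1 appears three times (three splits), key 2 once.
val originalPartitionValues = Seq(1, 1, 1, 2)

// Each key must contribute as many Spark partitions as it has splits on this side.
val numExpectedPartitions = originalPartitionValues
  .groupBy(identity)
  .map { case (key, occurrences) => (key, occurrences.size) }
// => Map(1 -> 3, 2 -> 1)
```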
@@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase
       plan: QueryPlan[T]): Partitioning = partitioning match {
     case HashPartitioning(exprs, numPartitions) =>
       HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
-    case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) =>
-      KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions,
-        partitionValues)
+    case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
+      KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
+        originalPartValues)
     case PartitioningCollection(partitionings) =>
       PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
     case RangePartitioning(ordering, numPartitions) =>
@@ -131,7 +131,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     // Has exactly one partition.
     val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution,
-      physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues))
+      physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues))
   }
 
   test("non-clustered distribution: no V2 catalog") {
@@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     EnsureRequirements.apply(smjExec) match {
       case ShuffledHashJoinExec(_, _, _, _, _,
         DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
-        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
+        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
           DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
         assert(left.expressions == a1 :: Nil)
         assert(attrs == a1 :: Nil)
