Skip to content

Commit

Permalink
Remove duplicated code
Browse files Browse the repository at this point in the history
  • Loading branch information
sunchao committed Nov 1, 2023
1 parent 6887c7a commit b16f725
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,8 @@ case class AnalyzePartitionCommand(

// Update the metastore if newly computed statistics are different from those
// recorded in the metastore.

val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession, tableMeta.identifier,
partitions.map(_.storage.locationUri))
val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
val newRowCount = rowCounts.get(p.spec)
val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
newStats.map(_ => p.copy(stats = newStats))
}

val (_, newPartitions) = CommandUtils.calculatePartitionStats(
sparkSession, tableMeta, partitions, Some(rowCounts))
if (newPartitions.nonEmpty) {
sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,31 @@ object CommandUtils extends Logging {
// Calculate table size as a sum of the visible partitions. See SPARK-21079
val partitions = sessionState.catalog.listPartitions(catalogTable.identifier)
logInfo(s"Starting to calculate sizes for ${partitions.length} partitions.")
val paths = partitions.map(_.storage.locationUri)
val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths)
val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
val newRowCount = partitionRowCount.flatMap(_.get(p.spec))
val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
newStats.map(_ => p.copy(stats = newStats))
}
val (sizes, newPartitions) = calculatePartitionStats(spark, catalogTable, partitions,
partitionRowCount)
(sizes.sum, newPartitions)
}
logInfo(s"It took ${(System.nanoTime() - startTime) / (1000 * 1000)} ms to calculate" +
s" the total size for table ${catalogTable.identifier}.")
(totalSize, newPartitions)
}

/**
 * Calculates the on-disk size of every given partition and derives updated
 * catalog entries for those partitions whose statistics have changed.
 *
 * @param spark the active session used to scan the partition locations
 * @param catalogTable the table the partitions belong to (its identifier is
 *                     used when computing location sizes)
 * @param partitions the partitions to measure
 * @param partitionRowCount optional per-partition row counts keyed by
 *                          partition spec; when absent, only sizes feed the
 *                          new statistics
 * @return a pair of (size in bytes for each partition, in the same order as
 *         `partitions`; partitions whose stats changed, carrying the new stats)
 */
def calculatePartitionStats(
    spark: SparkSession,
    catalogTable: CatalogTable,
    partitions: Seq[CatalogTablePartition],
    partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None):
  (Seq[Long], Seq[CatalogTablePartition]) = {
  // One size per partition, computed in a single pass over all locations.
  val locationSizes = calculateMultipleLocationSizes(
    spark, catalogTable.identifier, partitions.map(_.storage.locationUri))
  // Pair each partition with its measured size; keep only partitions whose
  // statistics actually differ from what the metastore already records.
  val updatedPartitions = partitions.zip(locationSizes).flatMap {
    case (partition, size) =>
      val rowCount = partitionRowCount.flatMap(_.get(partition.spec))
      CommandUtils.compareAndGetNewStats(partition.stats, size, rowCount)
        .map(stats => partition.copy(stats = Some(stats)))
  }
  (locationSizes, updatedPartitions)
}

def calculateSingleLocationSize(
sessionState: SessionState,
identifier: TableIdentifier,
Expand Down

0 comments on commit b16f725

Please sign in to comment.