
Commit

update row count
sunchao committed Nov 1, 2023
1 parent 7ae76c1 commit 6887c7a
Showing 3 changed files with 56 additions and 43 deletions.
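In short (based on the diff and the updated test below): a scan-based ANALYZE of a partitioned table now records a row count for each partition, not just its size in bytes. A minimal usage sketch; the table name part_tbl is hypothetical:

// Hedged sketch: part_tbl stands in for any partitioned table.
// With NOSCAN only sizes are gathered, so per-partition row counts stay empty.
spark.sql("ANALYZE TABLE part_tbl COMPUTE STATISTICS NOSCAN")
// Without NOSCAN the table is scanned and, after this change, each partition's
// rowCount statistic is recorded alongside its size.
spark.sql("ANALYZE TABLE part_tbl COMPUTE STATISTICS")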
AnalyzePartitionCommand.scala
@@ -17,15 +17,12 @@

package org.apache.spark.sql.execution.command

-import org.apache.spark.sql.{Column, Row, SparkSession}
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.util.PartitioningUtils
-import org.apache.spark.util.collection.Utils

/**
* Analyzes a given set of partitions to generate per-partition statistics, which will be used in
@@ -101,7 +98,7 @@ case class AnalyzePartitionCommand(
if (noscan) {
Map.empty
} else {
-calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec)
+CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec)
}

// Update the metastore if newly computed statistics are different from those
@@ -122,35 +119,5 @@
Seq.empty[Row]
}

-private def calculateRowCountsPerPartition(
-sparkSession: SparkSession,
-tableMeta: CatalogTable,
-partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = {
-val filter = if (partitionValueSpec.isDefined) {
-val filters = partitionValueSpec.get.map {
-case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value))
-}
-filters.reduce(And)
-} else {
-Literal.TrueLiteral
-}
-
-val tableDf = sparkSession.table(tableMeta.identifier)
-val partitionColumns = tableMeta.partitionColumnNames.map(Column(_))
-
-val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count()
-
-df.collect().map { r =>
-val partitionColumnValues = partitionColumns.indices.map { i =>
-if (r.isNullAt(i)) {
-ExternalCatalogUtils.DEFAULT_PARTITION_NAME
-} else {
-r.get(i).toString
-}
-}
-val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues)
-val count = BigInt(r.getLong(partitionColumns.size))
-(spec, count)
-}.toMap
-}
}
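Note: with the private helper deleted above, AnalyzePartitionCommand now delegates to the shared copy in CommandUtils (second file below). A call-shape sketch, assuming a table partitioned by ds and hr and pre-existing sparkSession/tableMeta bindings (all hypothetical):

val partitionValueSpec: Option[TablePartitionSpec] = Some(Map("ds" -> "2023-11-01"))
val rowCounts: Map[TablePartitionSpec, BigInt] =
  CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec)
// e.g. Map(Map("ds" -> "2023-11-01", "hr" -> "00") -> 25,
//          Map("ds" -> "2023-11-01", "hr" -> "01") -> 25)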
CommandUtils.scala
@@ -25,9 +25,11 @@ import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}

import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -37,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex}
import org.apache.spark.sql.internal.{SessionState, SQLConf}
import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.Utils

/**
* For the purpose of calculating total directory sizes, use this filter to
@@ -76,7 +79,9 @@ object CommandUtils extends Logging {

def calculateTotalSize(
spark: SparkSession,
-catalogTable: CatalogTable): (BigInt, Seq[CatalogTablePartition]) = {
+catalogTable: CatalogTable,
+partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None):
+(BigInt, Seq[CatalogTablePartition]) = {
val sessionState = spark.sessionState
val startTime = System.nanoTime()
val (totalSize, newPartitions) = if (catalogTable.partitionColumnNames.isEmpty) {
@@ -89,7 +94,8 @@
val paths = partitions.map(_.storage.locationUri)
val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths)
val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), None)
+val newRowCount = partitionRowCount.flatMap(_.get(p.spec))
+val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
newStats.map(_ => p.copy(stats = newStats))
}
(sizes.sum, newPartitions)
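Note: the new partitionRowCount parameter defaults to None, so existing callers of calculateTotalSize keep the old behavior (no row count handed to compareAndGetNewStats). A tiny illustration of the Option-of-Map lookup used above, with hypothetical partition specs:

val withCounts: Option[Map[TablePartitionSpec, BigInt]] =
  Some(Map(Map("ds" -> "2023-11-01") -> BigInt(25)))
withCounts.flatMap(_.get(Map("ds" -> "2023-11-01")))   // Some(25): row count gets applied
withCounts.flatMap(_.get(Map("ds" -> "2023-11-02")))   // None: only the size can change
val noCounts: Option[Map[TablePartitionSpec, BigInt]] = None
noCounts.flatMap(_.get(Map("ds" -> "2023-11-01")))     // None: same as before this change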
@@ -231,7 +237,15 @@
}
} else {
// Compute stats for the whole table
-val (newTotalSize, newPartitions) = CommandUtils.calculateTotalSize(sparkSession, tableMeta)
+val rowCounts: Map[TablePartitionSpec, BigInt] =
+if (noScan) {
+Map.empty
+} else {
+calculateRowCountsPerPartition(sparkSession, tableMeta, None)
+}
+val (newTotalSize, newPartitions) = CommandUtils.calculateTotalSize(
+sparkSession, tableMeta, Some(rowCounts))

val newRowCount =
if (noScan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count()))
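Note: after a scan-based analyze of the whole table, the refreshed per-partition statistics end up in the session catalog. A hedged verification sketch; the table name is hypothetical, and sessionState/SessionCatalog are internal, unstable APIs:

import org.apache.spark.sql.catalyst.TableIdentifier

val parts = spark.sessionState.catalog.listPartitions(TableIdentifier("part_tbl"))
parts.foreach { p =>
  // p.stats should now carry both sizeInBytes and a rowCount for scanned partitions.
  println(s"${p.spec} -> ${p.stats.map(s => (s.sizeInBytes, s.rowCount))}")
}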

@@ -444,4 +458,36 @@ object CommandUtils extends Logging {
case NonFatal(e) => logWarning(s"Exception when attempting to uncache $name", e)
}
}

+def calculateRowCountsPerPartition(
+sparkSession: SparkSession,
+tableMeta: CatalogTable,
+partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = {
+val filter = if (partitionValueSpec.isDefined) {
+val filters = partitionValueSpec.get.map {
+case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value))
+}
+filters.reduce(And)
+} else {
+Literal.TrueLiteral
+}
+
+val tableDf = sparkSession.table(tableMeta.identifier)
+val partitionColumns = tableMeta.partitionColumnNames.map(Column(_))
+
+val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count()
+
+df.collect().map { r =>
+val partitionColumnValues = partitionColumns.indices.map { i =>
+if (r.isNullAt(i)) {
+ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+} else {
+r.get(i).toString
+}
+}
+val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues)
+val count = BigInt(r.getLong(partitionColumns.size))
+(spec, count)
+}.toMap
+}
}
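Two details of the moved helper worth calling out: rows with a null partition value are counted under the catalog's default partition name, and the grouping key is the full partition spec, which is what calculateTotalSize uses to look the counts back up via p.spec. A small sketch of the spec produced for a null value; the partition column ds is hypothetical:

// A row with ds = NULL is counted under this spec rather than a literal "null" key.
val nullSpec: TablePartitionSpec =
  Map("ds" -> ExternalCatalogUtils.DEFAULT_PARTITION_NAME)  // "__HIVE_DEFAULT_PARTITION__"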
StatisticsSuite.scala
@@ -417,9 +417,9 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton
partitionDates.foreach { ds =>
val partStats = queryStats(ds)
assert(partStats.nonEmpty)
-// The scan option doesn't update partition row count, only size in bytes.
+// The scan option should update partition row count
assert(partStats.get.sizeInBytes == 4411)
-assert(partStats.get.rowCount.isEmpty)
+assert(partStats.get.rowCount.get == 25)
}
}
}
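The flipped assertions capture the new behavior: partition size stays the same, but each partition now reports 25 rows after a scan-based analyze. An equivalent interactive check, with hypothetical table and partition names:

spark.sql("DESCRIBE TABLE EXTENDED part_tbl PARTITION (ds = '2023-11-01')").show(truncate = false)
// The Partition Statistics row should list the size in bytes and, now, the row count as well.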
