
Commit 30ac539

improve test cases and documentation
1 parent 9dc0ca0 commit 30ac539

3 files changed: 199 additions, 57 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala

Lines changed: 29 additions & 24 deletions
@@ -59,10 +59,12 @@ case class AnalyzeColumnCommand(
 
   def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = {
     val (rowCount, columnStats) = computeColStats(sparkSession, relation)
+    // We also update table-level stats in order to keep them consistent with column-level stats.
     val statistics = Statistics(
       sizeInBytes = newTotalSize,
       rowCount = Some(rowCount),
-      colStats = columnStats ++ catalogTable.stats.map(_.colStats).getOrElse(Map()))
+      // Newly computed column stats should override the existing ones.
+      colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ columnStats)
     sessionState.catalog.alterTable(catalogTable.copy(stats = Some(statistics)))
     // Refresh the cached data source table in the catalog.
     sessionState.catalog.refreshTable(tableIdentWithDB)
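
Note on the colStats change above: Scala's Map ++ keeps the right-hand operand's value when keys collide, so placing columnStats on the right lets the freshly computed column stats override whatever is already stored in the catalog. A minimal standalone sketch with hypothetical values (not part of this patch):

  // Right-hand operand wins on duplicate keys.
  val existing = Map("key" -> "staleStat")   // stats already persisted in the catalog
  val computed = Map("key" -> "freshStat")   // stats just computed by ANALYZE
  assert((existing ++ computed)("key") == "freshStat")  // new operand order: fresh stats win
  assert((computed ++ existing)("key") == "staleStat")  // old operand order kept the stale value
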
@@ -90,8 +92,9 @@ case class AnalyzeColumnCommand(
       }
     }
     if (duplicatedColumns.nonEmpty) {
-      logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", ")")} detected " +
-        s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, ignoring them.")
+      logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` statement. " +
+        s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " +
+        s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.")
     }
 
     // Collect statistics per column.
@@ -116,42 +119,44 @@ case class AnalyzeColumnCommand(
 }
 
 object ColumnStatStruct {
-  val zero = Literal(0, LongType)
-  val one = Literal(1, LongType)
+  private val zero = Literal(0, LongType)
+  private val one = Literal(1, LongType)
 
-  def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
-  def max(e: Expression): Expression = Max(e)
-  def min(e: Expression): Expression = Min(e)
-  def ndv(e: Expression, relativeSD: Double): Expression = {
+  private def numNulls(e: Expression): Expression = {
+    if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
+  }
+  private def max(e: Expression): Expression = Max(e)
+  private def min(e: Expression): Expression = Min(e)
+  private def ndv(e: Expression, relativeSD: Double): Expression = {
     // the approximate ndv should never be larger than the number of rows
     Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
   }
-  def avgLength(e: Expression): Expression = Average(Length(e))
-  def maxLength(e: Expression): Expression = Max(Length(e))
-  def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
-  def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
+  private def avgLength(e: Expression): Expression = Average(Length(e))
+  private def maxLength(e: Expression): Expression = Max(Length(e))
+  private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
+  private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
 
-  def getStruct(exprs: Seq[Expression]): CreateStruct = {
+  private def getStruct(exprs: Seq[Expression]): CreateStruct = {
     CreateStruct(exprs.map { expr: Expression =>
       expr.transformUp {
         case af: AggregateFunction => af.toAggregateExpression()
       }
     })
   }
 
-  def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
+  private def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
     Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
   }
 
-  def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
+  private def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
     Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
   }
 
-  def binaryColumnStat(e: Expression): Seq[Expression] = {
+  private def binaryColumnStat(e: Expression): Seq[Expression] = {
     Seq(numNulls(e), avgLength(e), maxLength(e))
   }
 
-  def booleanColumnStat(e: Expression): Seq[Expression] = {
+  private def booleanColumnStat(e: Expression): Seq[Expression] = {
     Seq(numNulls(e), numTrues(e), numFalses(e))
   }
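
The ndv helper above wraps the HyperLogLog++ estimate in Least(..., Count(one)) because the sketch can overestimate, and a distinct-count estimate larger than the row count would be meaningless. A rough DataFrame-level analogue of the same clamping idea (a sketch only, assuming a running SparkSession named spark and the approx_count_distinct function; the command itself builds these Catalyst expressions directly rather than calling the functions below):

  import org.apache.spark.sql.functions.{approx_count_distinct, col, count, least, lit}

  // 100 rows with 10 distinct keys; the HLL++ estimate is capped at the exact row count.
  val df = spark.range(0, 100).selectExpr("id % 10 AS key")
  df.agg(
    least(
      approx_count_distinct(col("key"), 0.05),  // relativeSD = 0.05, matching the ndv.maxError default
      count(lit(1))
    ).as("ndv")
  ).show()
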

@@ -162,14 +167,14 @@ object ColumnStatStruct {
     }
   }
 
-  def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match {
+  def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType match {
     // Use aggregate functions to compute statistics we need.
-    case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(e, relativeSD))
-    case StringType => getStruct(stringColumnStat(e, relativeSD))
-    case BinaryType => getStruct(binaryColumnStat(e))
-    case BooleanType => getStruct(booleanColumnStat(e))
+    case _: NumericType | TimestampType | DateType => getStruct(numericColumnStat(attr, relativeSD))
+    case StringType => getStruct(stringColumnStat(attr, relativeSD))
+    case BinaryType => getStruct(binaryColumnStat(attr))
+    case BooleanType => getStruct(booleanColumnStat(attr))
     case otherType =>
       throw new AnalysisException("Analyzing columns is not supported for column " +
-        s"${e.name} of data type: ${e.dataType}.")
+        s"${attr.name} of data type: ${attr.dataType}.")
   }
 }
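
Taken together, apply() dispatches to the helpers above, so the per-column struct layout is (numNulls, max, min, ndv) for numeric, timestamp and date columns, (numNulls, avgLength, maxLength, ndv) for strings, (numNulls, avgLength, maxLength) for binary, and (numNulls, numTrues, numFalses) for booleans; any other type throws an AnalysisException. A minimal usage sketch that exercises this path, with a hypothetical table name, mirroring the tests below:

  // Assumes a SparkSession named `spark`; column-level stats are stored in the catalog per column.
  spark.sql("CREATE TABLE t (c1 INT, c2 STRING) USING PARQUET")
  spark.sql("INSERT INTO t SELECT 1, 'a'")
  spark.sql("INSERT INTO t SELECT 2, 'bb'")
  spark.sql("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS c1, c2")
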

sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 2 additions & 1 deletion
@@ -578,7 +578,8 @@ object SQLConf {
   val NDV_MAX_ERROR =
     SQLConfigBuilder("spark.sql.statistics.ndv.maxError")
      .internal()
-      .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.")
+      .doc("The maximum estimation error allowed in HyperLogLog++ algorithm when generating " +
+        "column level statistics.")
      .doubleConf
      .createWithDefault(0.05)
 
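
Since the entry is marked internal, it is mostly adjusted for testing or tuning. A sketch of overriding the 0.05 default at session build time (hypothetical app name and master):

  // A tighter relativeSD gives more accurate, but more expensive, NDV estimates for ANALYZE ... FOR COLUMNS.
  val spark = org.apache.spark.sql.SparkSession.builder()
    .master("local[*]")
    .appName("ndv-max-error-demo")
    .config("spark.sql.statistics.ndv.maxError", "0.01")
    .getOrCreate()
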

sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala

Lines changed: 168 additions & 32 deletions
@@ -21,7 +21,7 @@ import java.io.{File, PrintWriter}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
 import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
@@ -358,53 +358,189 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils
     }
   }
 
-  test("generate column-level statistics and load them from hive metastore") {
+  private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): (Statistics, Statistics) = {
+    val tableName = "tbl"
+    var statsBeforeUpdate: Statistics = null
+    var statsAfterUpdate: Statistics = null
+    withTable(tableName) {
+      val tableIndent = TableIdentifier(tableName, Some("default"))
+      val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+      sql(s"CREATE TABLE $tableName (key int) USING PARQUET")
+      sql(s"INSERT INTO $tableName SELECT 1")
+      if (isAnalyzeColumns) {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
+      } else {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+      }
+      // Table lookup will make the table cached.
+      catalog.lookupRelation(tableIndent)
+      statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent)
+        .asInstanceOf[LogicalRelation].catalogTable.get.stats.get
+
+      sql(s"INSERT INTO $tableName SELECT 2")
+      if (isAnalyzeColumns) {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
+      } else {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+      }
+      catalog.lookupRelation(tableIndent)
+      statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent)
+        .asInstanceOf[LogicalRelation].catalogTable.get.stats.get
+    }
+    (statsBeforeUpdate, statsAfterUpdate)
+  }
+
+  test("test refreshing table stats of cached data source table by `ANALYZE TABLE` statement") {
+    val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = false)
+
+    assert(statsBeforeUpdate.sizeInBytes > 0)
+    assert(statsBeforeUpdate.rowCount.contains(1))
+
+    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
+    assert(statsAfterUpdate.rowCount.contains(2))
+  }
+
+  test("test refreshing column stats of cached data source table by `ANALYZE TABLE` statement") {
+    val (statsBeforeUpdate, statsAfterUpdate) = getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
+
+    assert(statsBeforeUpdate.sizeInBytes > 0)
+    assert(statsBeforeUpdate.rowCount.contains(1))
+    StatisticsTest.checkColStat(
+      dataType = IntegerType,
+      colStat = statsBeforeUpdate.colStats("key"),
+      expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+      rsd = spark.sessionState.conf.ndvMaxError)
+
+    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
+    assert(statsAfterUpdate.rowCount.contains(2))
+    StatisticsTest.checkColStat(
+      dataType = IntegerType,
+      colStat = statsAfterUpdate.colStats("key"),
+      expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
+      rsd = spark.sessionState.conf.ndvMaxError)
+  }
+
+  private def dataAndColStats(): (DataFrame, Seq[(StructField, ColumnStat)]) = {
     import testImplicits._
 
     val intSeq = Seq(1, 2)
     val stringSeq = Seq("a", "bb")
+    val binarySeq = Seq("a", "bb").map(_.getBytes)
     val booleanSeq = Seq(true, false)
-
     val data = intSeq.indices.map { i =>
-      (intSeq(i), stringSeq(i), booleanSeq(i))
+      (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
     }
-    val tableName = "table"
-    withTable(tableName) {
-      val df = data.toDF("c1", "c2", "c3")
-      df.write.format("parquet").saveAsTable(tableName)
-      val expectedColStatsSeq = df.schema.map { f =>
-        val colStat = f.dataType match {
-          case IntegerType =>
-            ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
-          case StringType =>
-            ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
-              stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
-          case BooleanType =>
-            ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
-              booleanSeq.count(_.equals(false)).toLong))
-        }
-        (f, colStat)
+    val df = data.toDF("c1", "c2", "c3", "c4")
+    val expectedColStatsSeq = df.schema.map { f =>
+      val colStat = f.dataType match {
+        case IntegerType =>
+          ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, intSeq.distinct.length.toLong))
+        case StringType =>
+          ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / stringSeq.length.toDouble,
+            stringSeq.map(_.length).max.toInt, stringSeq.distinct.length.toLong))
+        case BinaryType =>
+          ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / binarySeq.length.toDouble,
+            binarySeq.map(_.length).max.toInt))
+        case BooleanType =>
+          ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
+            booleanSeq.count(_.equals(false)).toLong))
       }
+      (f, colStat)
+    }
+    (df, expectedColStatsSeq)
+  }
+
+  private def checkColStats(
+      tableName: String,
+      isDataSourceTable: Boolean,
+      expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
+    val readback = spark.table(tableName)
+    val stats = readback.queryExecution.analyzed.collect {
+      case rel: MetastoreRelation =>
+        assert(!isDataSourceTable, "Expected a Hive serde table, but got a data source table")
+        rel.catalogTable.stats.get
+      case rel: LogicalRelation =>
+        assert(isDataSourceTable, "Expected a data source table, but got a Hive serde table")
+        rel.catalogTable.get.stats.get
+    }
+    assert(stats.length == 1)
+    val columnStats = stats.head.colStats
+    assert(columnStats.size == expectedColStatsSeq.length)
+    expectedColStatsSeq.foreach { case (field, expectedColStat) =>
+      StatisticsTest.checkColStat(
+        dataType = field.dataType,
+        colStat = columnStats(field.name),
+        expectedColStat = expectedColStat,
+        rsd = spark.sessionState.conf.ndvMaxError)
+    }
+  }
 
-      sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, c3")
-      val readback = spark.table(tableName)
-      val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
-        val columnStats = rel.catalogTable.get.stats.get.colStats
-        expectedColStatsSeq.foreach { case (field, expectedColStat) =>
-          assert(columnStats.contains(field.name))
-          val colStat = columnStats(field.name)
+  test("generate and load column-level stats for data source table") {
+    val dsTable = "dsTable"
+    withTable(dsTable) {
+      val (df, expectedColStatsSeq) = dataAndColStats()
+      df.write.format("parquet").saveAsTable(dsTable)
+      sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
+      checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
+    }
+  }
+
+  test("generate and load column-level stats for hive serde table") {
+    val hTable = "hTable"
+    val tmp = "tmp"
+    withTable(hTable, tmp) {
+      val (df, expectedColStatsSeq) = dataAndColStats()
+      df.write.format("parquet").saveAsTable(tmp)
+      sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) STORED AS TEXTFILE")
+      sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
+      sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, c4")
+      checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
+    }
+  }
+
+  // When caseSensitive is on, for columns with only case difference, they are different columns
+  // and we should generate column stats for all of them.
+  private def checkCaseSensitiveColStats(columnName: String): Unit = {
+    val tableName = "tbl"
+    withTable(tableName) {
+      val column1 = columnName.toLowerCase
+      val column2 = columnName.toUpperCase
+      withSQLConf("spark.sql.caseSensitive" -> "true") {
+        sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) USING PARQUET")
+        sql(s"INSERT INTO $tableName SELECT 1, 3.0")
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS `$column1`, `$column2`")
+        val readback = spark.table(tableName)
+        val relations = readback.queryExecution.analyzed.collect { case rel: LogicalRelation =>
+          val columnStats = rel.catalogTable.get.stats.get.colStats
+          assert(columnStats.size == 2)
+          StatisticsTest.checkColStat(
+            dataType = IntegerType,
+            colStat = columnStats(column1),
+            expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+            rsd = spark.sessionState.conf.ndvMaxError)
           StatisticsTest.checkColStat(
-            dataType = field.dataType,
-            colStat = colStat,
-            expectedColStat = expectedColStat,
+            dataType = DoubleType,
+            colStat = columnStats(column2),
+            expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
             rsd = spark.sessionState.conf.ndvMaxError)
+          rel
         }
-        rel
+        assert(relations.size == 1)
       }
-      assert(relations.size == 1)
     }
   }
 
+  test("check column statistics for case sensitive column names") {
+    checkCaseSensitiveColStats(columnName = "c1")
+  }
+
+  test("check column statistics for case sensitive non-ascii column names") {
+    // scalastyle:off
+    // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+    checkCaseSensitiveColStats(columnName = "列c")
+    // scalastyle:on
+  }
+
   test("estimates the size of a test MetastoreRelation") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
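
A note for reading the expected values in these tests: the InternalRow inside each expected ColumnStat follows the field order produced by the corresponding helper in ColumnStatStruct. For the int column `key` after the second insert, for example (a restatement of the expectation above, not new behavior):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.plans.logical.ColumnStat

  // numericColumnStat order is (numNulls, max, min, ndv):
  // 0 nulls, max = 2, min = 1, ndv = 2 after the rows 1 and 2 have been inserted.
  val expectedForKey = ColumnStat(InternalRow(0L, 2, 1, 2L))
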
